1、假設每個樣本爲一類,計算每個類的距離,也就是相似度
2、把最近的兩個合爲一新類,這樣類別數量就少了一個
3、重新新類與各個舊類(去了那兩個合併的類)之間的相似度;
4、循環重複2和3直到所有樣本點都歸爲一類
這個計算的過程,相當於重構一個二叉樹,只是這個過程,是從樹葉-->樹枝-->樹幹的構建過程
本例將以14張圖片,做爲樣本,進行聚類,點擊這裏 下載圖片樣本
以下是使用我提供的圖片庫生成的分類結果,以及一張PS修後對代碼中各變量的說明
當然,你也可以自己定義一個目錄,程序會讀取目錄下所有JPG圖片
如果你用了自己的圖片,在代碼中的一此數據的變化說明,可能和使用產生的數據不同了,同時,本文的主要目的是層次聚類(hierarchical clustering)的基本步驟,對於圖片的相似度的算法並不完善,效果也並不是十分理想,不過如果你使用自己從手機中導入的生活照,不同的場景大致還是能分類出來的
# -*- coding:utf-8 -*-
from PIL import ImageDraw,Image
import numpy as np
import os
import sys
nodeList = []#用於存儲所有的節點,包含圖片節點,與聚類後的節點
distance = {}#用於存儲所有每兩個節點的距離,數據格式{(node1.id,node2.id):30.0,(node2.id,node3.id):40.0}
class node:
def __init__(self, data):
'''每個樣本及樣本合併後節點的類
data:接受兩種格式,
1、當爲字符(string)時,是圖片的地址,同時也表示這個節點就是圖片
2、合併後的類,傳入的格式爲(leftNode,rightNode) 即當前類表示合併後的新類,而對應的左右節點就是子節點
'''
self.id = len(nodeList)#設置一個ID,以nodeList當然長度爲ID,在本例中ID本身沒太大用處,只是如果看代碼時,有時要看指向時有點用
self.parent = None # 指向合併後的類
self.pos = None#用於最後繪製節構圖使用,賦值時爲(x,y,w,h)格式
if type(data) == type("") :
'''節點爲圖片'''
self.imgData = Image.open(data)
self.left = None
self.right = None
self.level = 0 #圖片爲最終的子節點,所有圖片的層級都爲0,設置層級是爲了最終繪製結構圖
npTmp = np.array(self.imgData).reshape(-1,3) #將圖片數據轉化爲numpy數據,shape爲(高,寬,3),3爲顏色通道
npTmp = npTmp.reshape(-1,3) #重新排列,shape爲(高*寬,3)
self.feature = npTmp.mean(axis=0)#計算RGB三個顏色通道均值
else:
'''節點爲合成的新類'''
self.imgData = None
self.left = data[0]
self.right = data[1]
self.left.parent = self
self.right.parent = self
self.level = max(self.left.level,self.right.level) + 1 #層級爲左右節高層級的級數+1
self.feature = (self.left.feature + self.right.feature) / 2 #兩類的合成一類時,就是左右節點的feature相加/2
#計算該類與每個其他類的距離,並存入distance
for x in nodeList:
distance[(x,self)] = np.sqrt(np.sum((x.feature - self.feature) ** 2))
nodeList.append(self)#將本類加入nodeList變量
def drawNode(self,img,draw,vLineLenght):
#繪製結構圖
if self.pos == None:return
if self.left == None:
#如果是圖片
self.imgData.thumbnail((self.pos[2], self.pos[3]))
img.paste(self.imgData,(self.pos[0], self.pos[1]))
draw.line((int(self.pos[0] + self.pos[2] / 2)
, self.pos[1] - vLineLenght
, int(self.pos[0] + self.pos[2] / 2)
, self.pos[1])
, fill=(255, 0, 0))
else:
#如果不是圖片
draw.line((int(self.pos[0])
, self.pos[1]
, int(self.pos[0] + self.pos[2])
, self.pos[1])
, fill=(255, 0, 0))
draw.line((int(self.pos[0] + self.pos[2] / 2)
, self.pos[1]
, int(self.pos[0] + self.pos[2] / 2)
, self.pos[1] - self.pos[3])
, fill=(255, 0, 0))
def loadImg(path):
'''path 圖片目錄,根據自己存的地方改寫'''
files = None
try :
files = os.listdir(path)
except:
print('未正確讀取目錄:' + path + ',圖片目錄,請根據自己存的地方改寫,並保證沒有hierarchicalResult.jpg,該文件爲最後生成文件')
return None
for i in files:
if os.path.splitext(i)[1].lower() == '.jpg' and os.path.splitext(i)[0].lower() != 'hierarchicalresult':
fileName = os.path.join(path,i)
node(fileName)
return os.path.join(path,'hierarchicalResult.jpg')
def getMinDistance():
'''從distance中過濾出未分類的結點,並讀取最小的距離'''
vars = list(filter(lambda x:x[0].parent == None and x[1].parent == None ,distance))
minDist = vars[0]
for x in vars:
if minDist == None or distance[x] < distance[minDist]:
minDist = x
return minDist
def createTree():
while len(list(filter(lambda x:x.parent == None ,nodeList))) > 1:#合併到最後時,只有一個類,只要有兩個以上未合併,就循環
minDist = getMinDistance()
#創建非圖片的節點,之所以把[1]做爲左節點,因爲繪圖時的需要,
#在不斷的產生非圖片節點時,在nodeList的後面的一般是新節點,但繪圖時繪在左邊
node((minDist[1],minDist[0]))
return nodeList[-1]#最後一個插入的節點就是要節點
def run():
root = createTree()#創建樹結構
#一句話的PYTON,實現二叉樹的左右根遍歷,通過通過遍歷,進行排序後,取出圖片,做爲最底層的打印
sortTree = lambda node:([] if node.left == None else sortTree(node.left)) + ([] if node.right == None else sortTree(node.right)) + [node]
treeTmp = sortTree(root)
treeTmp = list(filter(lambda x:x.left == None,treeTmp))#沒有左節點的,即爲圖片
thumbSize = 60 #縮略圖的大小,,在60X60的小格內縮放
thumbSpace = 20 #縮略圖間距
vLineLenght = 80 #上下節點,即每個level之間的高度
imgWidth = len(treeTmp) * (thumbSize + thumbSpace)
imgHeight = (root.level+1) * vLineLenght + thumbSize + thumbSpace*2
img = Image.new('RGB', (imgWidth,imgHeight), (255, 255, 255))
draw = ImageDraw.Draw(img)
for item in enumerate(treeTmp):
#爲所有圖片增加繪圖數據
x = item[0] * (thumbSize + thumbSpace) + thumbSpace / 2
y = imgHeight - thumbSize - thumbSpace / 2 - ((item[1].parent.level - 1) * vLineLenght)
w = item[1].imgData.width
h = item[1].imgData.height
if w > h:
h = h / w * thumbSize
w = thumbSize
else:
w = w / h * thumbSize
h = thumbSize
x+=(thumbSize - w) / 2
item[1].pos = (int(x),int(y),int(w),int(h))
item[1].drawNode(img,draw,vLineLenght)
for x in range(1,root.level + 1):
#爲所有非圖片增加繪圖的數據
items = list(filter(lambda i:i.level == x,nodeList))
for item in items:
x = item.left.pos[0] + (item.left.pos[2] / 2)
w = item.right.pos[0] + (item.right.pos[2] / 2) - x
y = item.left.pos[1] - (item.level - item.left.level) * vLineLenght
h = ((item.parent.level if item.parent != None else item.level + 1) - item.level) * vLineLenght
item.pos = (int(x),int(y),int(w),int(h))
item.drawNode(img,draw,vLineLenght)
img.save(resultFile)
resultFile = loadImg(r"E:\hierarchicalImgs")#讀取數據,並返回最後結果要存儲的文件名,目錄根據自己存的位置進行修改
if resultFile != 'None':
run()
print("結構圖生成成功,最終結構圖存儲於:" + resultFile)