機器學習(4)--層次聚類(hierarchical clustering)基本原理及實現簡單圖片分類

關於層次聚類(hierarchical clustering)的基本步驟:
1、假設每個樣本爲一類,計算每個類的距離,也就是相似度
2、把最近的兩個合爲一新類,這樣類別數量就少了一個
3、重新新類與各個舊類(去了那兩個合併的類)之間的相似度;
4、循環重複2和3直到所有樣本點都歸爲一類

這個計算的過程,相當於重構一個二叉樹,只是這個過程,是從樹葉-->樹枝-->樹幹的構建過程

本例將以14張圖片,做爲樣本,進行聚類,點擊這裏  下載圖片樣本

以下是使用我提供的圖片庫生成的分類結果,以及一張PS修後對代碼中各變量的說明



當然,你也可以自己定義一個目錄,程序會讀取目錄下所有JPG圖片

如果你用了自己的圖片,在代碼中的一此數據的變化說明,可能和使用產生的數據不同了,

同時,本文的主要目的是層次聚類(hierarchical clustering)的基本步驟,對於圖片的相似度的算法並不完善,效果也並不是十分理想,不過如果你使用自己從手機中導入的生活照,不同的場景大致還是能分類出來的

# -*- coding:utf-8 -*-

from PIL import ImageDraw,Image
import numpy as np
import os
import sys


nodeList = []#用於存儲所有的節點,包含圖片節點,與聚類後的節點
distance = {}#用於存儲所有每兩個節點的距離,數據格式{(node1.id,node2.id):30.0,(node2.id,node3.id):40.0}
class node:
    def __init__(self, data):
        '''每個樣本及樣本合併後節點的類
            data:接受兩種格式,
            1、當爲字符(string)時,是圖片的地址,同時也表示這個節點就是圖片
            2、合併後的類,傳入的格式爲(leftNode,rightNode) 即當前類表示合併後的新類,而對應的左右節點就是子節點
        '''
        self.id = len(nodeList)#設置一個ID,以nodeList當然長度爲ID,在本例中ID本身沒太大用處,只是如果看代碼時,有時要看指向時有點用
        self.parent = None # 指向合併後的類
        self.pos = None#用於最後繪製節構圖使用,賦值時爲(x,y,w,h)格式
        if type(data) == type("") :
            '''節點爲圖片'''
            self.imgData = Image.open(data)
            self.left = None
            self.right = None 
            self.level = 0    #圖片爲最終的子節點,所有圖片的層級都爲0,設置層級是爲了最終繪製結構圖

            npTmp = np.array(self.imgData).reshape(-1,3) #將圖片數據轉化爲numpy數據,shape爲(高,寬,3),3爲顏色通道
            npTmp = npTmp.reshape(-1,3)  #重新排列,shape爲(高*寬,3)
            self.feature = npTmp.mean(axis=0)#計算RGB三個顏色通道均值

        else:
            '''節點爲合成的新類'''
            self.imgData = None
            self.left = data[0]
            self.right = data[1]
            self.left.parent = self
            self.right.parent = self

            self.level = max(self.left.level,self.right.level) + 1 #層級爲左右節高層級的級數+1
            self.feature = (self.left.feature + self.right.feature) / 2 #兩類的合成一類時,就是左右節點的feature相加/2
            
        #計算該類與每個其他類的距離,並存入distance
        for x in nodeList:
            distance[(x,self)] = np.sqrt(np.sum((x.feature - self.feature) ** 2))

        nodeList.append(self)#將本類加入nodeList變量

    def drawNode(self,img,draw,vLineLenght):
        #繪製結構圖
        if self.pos == None:return
        if self.left == None:
            #如果是圖片
            self.imgData.thumbnail((self.pos[2], self.pos[3]))
            img.paste(self.imgData,(self.pos[0], self.pos[1]))
            draw.line((int(self.pos[0] + self.pos[2] / 2)
                 , self.pos[1] - vLineLenght
                 , int(self.pos[0] + self.pos[2] / 2)
                 , self.pos[1])
                , fill=(255, 0, 0))
        else:
            #如果不是圖片
            draw.line((int(self.pos[0])
                 , self.pos[1]
                 , int(self.pos[0] + self.pos[2])
                 , self.pos[1])
                , fill=(255, 0, 0))

            draw.line((int(self.pos[0] + self.pos[2] / 2)
                    , self.pos[1]
                    , int(self.pos[0] + self.pos[2] / 2)
                    , self.pos[1] - self.pos[3])
                    , fill=(255, 0, 0))

def loadImg(path):
    '''path 圖片目錄,根據自己存的地方改寫'''
    files = None
    try :
        files = os.listdir(path)
    except:
        print('未正確讀取目錄:' + path + ',圖片目錄,請根據自己存的地方改寫,並保證沒有hierarchicalResult.jpg,該文件爲最後生成文件')
        return None
    for i in files:

        if os.path.splitext(i)[1].lower() == '.jpg' and os.path.splitext(i)[0].lower() != 'hierarchicalresult':

            fileName = os.path.join(path,i)
            node(fileName)
    return os.path.join(path,'hierarchicalResult.jpg')

def getMinDistance():
    '''從distance中過濾出未分類的結點,並讀取最小的距離'''
    vars = list(filter(lambda x:x[0].parent == None and x[1].parent == None ,distance))
    minDist = vars[0]
    for x in vars:
        if minDist == None or distance[x] < distance[minDist]:
            minDist = x
    return minDist

def createTree():
    while len(list(filter(lambda x:x.parent == None ,nodeList))) > 1:#合併到最後時,只有一個類,只要有兩個以上未合併,就循環
        minDist = getMinDistance()
        #創建非圖片的節點,之所以把[1]做爲左節點,因爲繪圖時的需要,
        #在不斷的產生非圖片節點時,在nodeList的後面的一般是新節點,但繪圖時繪在左邊
        node((minDist[1],minDist[0])) 
    return nodeList[-1]#最後一個插入的節點就是要節點


def run():
    root = createTree()#創建樹結構

    #一句話的PYTON,實現二叉樹的左右根遍歷,通過通過遍歷,進行排序後,取出圖片,做爲最底層的打印
    sortTree = lambda node:([] if node.left == None else sortTree(node.left)) + ([] if node.right == None else sortTree(node.right)) + [node]
    treeTmp = sortTree(root)
    treeTmp = list(filter(lambda x:x.left == None,treeTmp))#沒有左節點的,即爲圖片

    thumbSize = 60 #縮略圖的大小,,在60X60的小格內縮放
    thumbSpace = 20 #縮略圖間距
    vLineLenght = 80 #上下節點,即每個level之間的高度

    imgWidth = len(treeTmp) * (thumbSize + thumbSpace)
    imgHeight = (root.level+1) * vLineLenght + thumbSize + thumbSpace*2
    img = Image.new('RGB', (imgWidth,imgHeight), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    for item in enumerate(treeTmp):
        #爲所有圖片增加繪圖數據
        x = item[0] * (thumbSize + thumbSpace) + thumbSpace / 2
        y = imgHeight - thumbSize - thumbSpace / 2 - ((item[1].parent.level - 1) * vLineLenght)
        w = item[1].imgData.width
        h = item[1].imgData.height
        if w > h:
            h = h / w * thumbSize
            w = thumbSize
        else:
            w = w / h * thumbSize
            h = thumbSize
            x+=(thumbSize - w) / 2
        item[1].pos = (int(x),int(y),int(w),int(h))
        item[1].drawNode(img,draw,vLineLenght)

    for x in range(1,root.level + 1):
        #爲所有非圖片增加繪圖的數據
        items = list(filter(lambda i:i.level == x,nodeList))
        for item in items:
            x = item.left.pos[0] + (item.left.pos[2] / 2)
            w = item.right.pos[0] + (item.right.pos[2] / 2) - x
            y = item.left.pos[1] - (item.level - item.left.level) * vLineLenght
            h = ((item.parent.level if item.parent != None else item.level + 1) - item.level) * vLineLenght
            item.pos = (int(x),int(y),int(w),int(h))
            item.drawNode(img,draw,vLineLenght)
    img.save(resultFile)

resultFile = loadImg(r"E:\hierarchicalImgs")#讀取數據,並返回最後結果要存儲的文件名,目錄根據自己存的位置進行修改
if resultFile != 'None':
    run()
    print("結構圖生成成功,最終結構圖存儲於:" + resultFile)



發佈了46 篇原創文章 · 獲贊 14 · 訪問量 2萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章