監督學習——決策樹理論與實踐(下):迴歸決策樹(CART)

介紹

決策樹分爲分類決策樹和迴歸決策樹:

bacd

上一篇介紹了分類決策樹以及Python實現分類決策樹: 監督學習——決策樹理論與實踐(上):分類決策樹

         決策樹是一種依託決策而建立起來的一種樹。在機器學習中,決策樹是一種預測模型,代表的是一種對象屬性與對象值之間的一種映射關係,每一個節點代表某個對象/分類,樹中的每一個分叉路徑代表某個可能的屬性值,而每一個葉子節點則對應從根節點到該葉子節點所經歷的路徑所表示的對象的值

        通過訓練數據構建決策樹,可以高效的對未知的數據進行分類。決策數有兩大優點:1)決策樹模型可以讀性好,具有描述性,有助於人工分析;2)效率高,決策樹只需要一次構建,反覆使用,每一次預測的最大計算次數不超過決策樹的深度。

      決策樹是一顆樹形的數據結構,可以是多叉樹也可以是二叉樹,決策樹實際上是一種基於貪心策略構造的,每次選擇的都是最優的屬性進行分裂。

      決策樹也是一種監督學習算法,它的樣本是(x,y)形式的輸入輸出樣例。

  迴歸樹:

         相對於上一篇所講的決策樹,這篇所講的迴歸樹主要解決迴歸問題,所以給定的訓練數據輸入和標籤都是連續的。


CART迴歸樹生成算法

決策樹的生成

        CART算法的思路是將特徵空間切分爲m個不同的子空間,通過測試數據(落在每個子空間中的測試數據)來計算每個子空間的輸出值(對應下式中的Cm)。當這樣的空間幾何生成之後就可以很方便的將一個未知數據映射到某一個子空間Ri中,將Ci的值作爲該未知數據的輸出值。

image

這裏Cm的取值一般採用均值算法,即取所有落在該子空間的測試數據的均作爲該子空間的值:

image

這裏肯定會涉及到一個,這也是CART算法的關鍵: 如何去劃分一個一個子空間?如何去選擇第j個變量Xj和它取值s作爲切分變量和切分點,並定義成兩個區域。這裏《統計學方法》中給出了算法思路:

image

算法實現時,比那裏所有切分向量,切分點是測試數據在Xj上的所有取值集合。通過5.19就能計算出當前最佳的切分向量j和切分點x以及劃分成的兩個區域的取值c1,c2。(該部分的Python實現對應下文中chooseBestSplit函數)

      當對一個整體測試數據調用上面邏輯後會得到一個j和x值,通過這兩個值將空間分成了兩個空間,再分別對兩個子空間調用上面的邏輯,這樣遞歸下去就能生成一棵決策樹。(對應下文中createTree函數

決策樹的剪枝

CART剪枝算法從“完全生長”的決策樹的底端減去一些子樹,使決策樹變小(模型變簡單),從而能夠對未知數據有更準確的預測。

後續待補充


CART算法Python實現

數據加載

加載測試數據,以及測試數據的值(X,Y),這裏數據和值都存放在一個矩陣中。

def loadDataSet(fileName):      #general function to parse tab -delimited floats
    dataMat = []                #assume last column is target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float,curLine) #map all elements to float()
        dataMat.append(fltLine)
    return dataMat

數據劃分

該函數用於切分數據集,將測試數據某一列中的元素大於和小於的測試數據分開,分別放到兩個矩陣中:

def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1

輸入參數 feature 爲指定的某一列

value爲切分點的值,通過該該值將dataset一份爲二

尋找最優切分特徵以及切分點

這裏涉及到三個函數,分別在代碼註釋中進行了說明,真正計算最優值的函數爲最後一個。

# 葉節點值計算函數: 這裏以均值作爲葉節點值
def regLeaf(dataSet):#returns the value used for each leaf
    return mean(dataSet[:,-1])

# 預測誤差計算函數:這裏用均方差表示
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

# 遍歷每一列中每個value值,找到最適合分裂的列和切分點
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0];   # 均方差最小優化值,如果大於該值則沒有必要切分
    tolN = ops[1]    # 需要切分數據的最小長度,如果已經小於該值,則無需再切分
    #if all the target variables are the same value: quit and return value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit cond 1
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:,featIndex]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don't do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet) #exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit cond 3
        return None, leafType(dataSet)
    return bestIndex,bestValue#returns the best feature to split on
                              #and the value used for that split

迴歸樹的創建

        在上面函數基礎之上,創建一個迴歸樹也就不難了:

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
    if feat == None: return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

這裏是一個遞歸調用,需要注意函數的終止條件,這裏當數據集不能再分時纔會觸發終止條件,實際中這種操作很有可能會出現過擬合,可以認爲地加一些終止條件進行“預剪枝”

 

參考:

《機器學習實戰》

《統計學習方法》

https://blog.csdn.net/u014568921/article/details/45082197

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章