機器學習實戰之樹迴歸(CART)python實現(附python3代碼)

樹迴歸
CART(Classification And Regression Tree, 分類迴歸樹)

完整代碼github
環境 python3

決策樹分類

決策樹不斷將數據切分成小數據集,直到所有的目標變量完全相同,或者數據不能再切分爲止。
決策樹是一種貪心算法,它要在給定的時間內做出最佳選擇,但不關心是否達到全局最優。

ID3的做法是每次選取當前最佳的特徵來分割數據,並按照該特徵的所有取值進行切分,一旦按某種特徵進行切分後,該特徵在之後的算法執行過程中就不在起作用。這種切分方式過於迅速且不能處理連續型特徵,只有事先將連續型特徵轉換爲離散型特徵,才能在ID3算法中使用,但是這種轉換會破壞連續型變量的內在性質。
詳細介紹可以參照決策樹python實現(ID3 和 C4.5)

樹迴歸

優點:可以對複雜和非線性的數據建模
缺點:結果不易理解
適用數據類型:數值型和標稱型數據

CART使用二元切分法來處理連續型變量,對CART稍作修改就可以處理迴歸問題。
二元切分法:每次把數據集切分成兩份,如果數據的某特徵值大於切分所要求的值,那麼這些數據進入樹的左子樹,反之則進入樹的右子樹。二元切分法節省了樹的構建時間,這點意義不大,因爲樹的構建一般是離線完成。

樹迴歸的一般方法
(1)收集數據:任意方法
(2)準備數據:需要數值型數據,標稱型數據應該映射成二值型數據
(3)分析數據:繪出數據的二維可視化顯示結果,以字典方式生成樹
(4)訓練算法:大部分時間都花在葉節點樹模型的構建上
(5)測試算法:使用測試數據上的R^2值來分析模型的效果
(6)使用算法:使用訓練出的樹做預測,預測結果可以用來做許多事情

連續和離散型特徵的樹構建

用字典來存儲樹的數據結構,該字典包含以下4元素:

  • 待切分的特徵
  • 待切分的特徵值
  • 右子樹,當不再需要切分的時候,也可以是單個值
  • 左子樹,與右子樹類似

ID3用一部字典來存儲每個切分,該字典可以包含兩個或兩個以上的值。
CART算法只做二元切分,所以這裏可以固定樹的數據結構。樹包含左鍵和右鍵,可以存儲另一棵子樹或者單個值;字典還包含特徵和特徵值這兩個鍵,它們給出切分算法所有的特徵和特徵值。
本章構建兩種樹:迴歸樹和模型樹
迴歸樹:其每個節點包含單個值
模型樹:其每個節點包含一個線性方程

函數createTree()的僞代碼

找到最佳切分特徵:
    如果該節點不能再分,將該節點存爲葉節點
    執行二元切分
    在右子樹調用createTree()方法
    在左子樹調用createTree()方法

CART算法的實現代碼

from numpy import *
import numpy as np

def loadDataSet(filename):
    # 加載數據集
    '''
    :param filename:
    :return: 數據+標籤列表
    '''
    dataMat = []
    fr = open(filename)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    '''

    :param dataSet:數據集合
    :param feature: 待切分的特徵
    :param value: 該特徵的某個值
    :return: 通過數組過濾的方式將上述數據集合切分得到兩個子集並返回
    '''
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    '''
    遞歸函數:樹構建
    :param dataSet:數據集 
    :param leafType:建立葉節點函數 
    :param errType: 誤差計算函數
    :param ops: 包含樹構建所需的其他函數的元組(tolS,tolN), tolS:容許的誤差下降值,tolN:切分的最少樣本數
    :return: 
    '''
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # 滿足停止條件時返回葉節點值
    if feat == None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

將CART算法用於迴歸

模型樹:把葉節點設定爲分段線性函數,模型樹的可解釋性是其優於迴歸樹的特點之一
誤差計算:對於給定數據集,應先用線性模型來對它進行擬合,然後計算真實目標值和模型預測值之間的差值,最後將這些差值的平方求和就得到了所需的誤差。
補充一些新的代碼,使createTree()運行。
實現chooseBestSplit函數:給定某個誤差計算方法,該函數會找到數據集上的最佳二元切分方式。該函數還要確定什麼時候停止切分,一旦停止切分會生成一個葉節點。即用最佳方式切分數據集和生成相應的葉節點。
僞代碼

對每個特徵:
	對每個特徵值:
		將數據切分成兩份
		計算切分誤差
		如果當前誤差小於當前最小誤差,則將當前切分設定爲最佳切分並更新最小誤差
返回最佳切分的特徵和閾值

python代碼如下

def regLeaf(dataSet):
    '''
    生成葉節點
    :param dataSet: 數據集
    :return:
    '''
    return mean(dataSet[:, -1])


def regErr(dataSet):
    '''
    誤差估計函數,在給定數據上計算目標變量的平方誤差,總方差
    :param dataSet:
    :return:
    '''
    return var(dataSet[:, -1])*shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    找到數據的最佳二元切分方式
    :param dataSet: 數據集
    :param leafType:建立葉節點的函數
    :param errType: 誤差計算函數
    :param ops: 包含樹構建所需其他參數的元組
    :return: 特徵編號和切分特徵值
    '''
    tolS = ops[0]  # 容許的誤差下降值
    tolN = ops[1]  # 切分的最少樣本數
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)  # 當前數據集的大小
    S = errType(dataSet)   # 當前數據集的誤差
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0])<tolN or shape(mat1)[0]<tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestS = newS
                bestValue = splitVal
    # 切分後如果誤差減少不大,則不應該進行切分操作,直接創建葉節點
    if S-bestS < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # 切分後的兩個子集大小是否小於用戶自定義的大小,則不應切分
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue

運行測試代碼:

 	 myDat = loadDataSet('ex00.txt')
    # # plot_db(myDat)
     myMat = mat(myDat)
     retTree = createTree(myMat)
     print('retTree  ', retTree)

結果:

retTree   {'right': -0.04465028571428572, 'left': 1.0180967672413792, 'spVal': 0.48813, 'spInd': 0}

可以看到,樹中包含2個葉節點

運行測試代碼:

	myDat1 = loadDataSet('ex0.txt')
    plot_db(myDat1)
    myMat1 = mat(myDat1)
    retTree1 = createTree(myMat1)
    print('retTree1  ', retTree1)

結果:

retTree1   {'spVal': 0.39435, 'left': {'spVal': 0.582002, 'left': {'spVal': 0.797583, 'left': 3.9871632, 'spInd': 1, 'right': 2.9836209534883724}, 'spInd': 1, 'right': 1.980035071428571}, 'spInd': 1, 'right': {'spVal': 0.197834, 'left': 1.0289583666666666, 'spInd': 1, 'right': -0.023838155555555553}}

可以看到,樹中包含5個葉節點

樹剪枝

如果一棵樹的節點過多,表明該模型可能發生了過擬合==。
之前的算法都是使用了測試集上某種交叉驗證技術來發現過擬合。決策樹也是如此。
通過降低決策樹的複雜度來避免過擬合的過程稱爲剪枝
在函數chooseBestSplit()中提前終止條件,實際上是所謂的預剪枝操作。
另一種形式的剪枝需要使用測試集和訓練集,稱作後剪枝

預剪枝

樹構建算法對輸入的參數tolS和tolN非常敏感,如果選用其他值,構建的樹效果不太好,例如,測試代碼:

	myDat = loadDataSet('ex00.txt')
    # plot_db(myDat)
    myMat = mat(myDat)
    retTree = createTree(myMat)
    print('retTree  ', retTree)
	retTree_1 = createTree(myMat, ops=(0, 1))
    print('retTree_1  ', retTree_1)

結果爲:

retTree   {'left': 1.0180967672413792, 'spInd': 0, 'spVal': 0.48813, 'right': -0.04465028571428572}
retTree_1   {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 1.035533, 'spInd': 0, 'spVal': 0.993349, 'right': 1.077553}, 'spInd': 0, 'spVal': 0.989888, 'right': {'left': 0.744207, 'spInd': 0, 'spVal': 0.988852, 'right': 1.069062}}, 'spInd': 0, 'spVal': 0.985425, 'right': 1.227946}, 'spInd': 0, 'spVal': 0.976414, 'right': {'left': {'left': 0.862911, 'spInd': 0, 'spVal': 0.975022, 'right': 0.673579}, 'spInd': 0, 'spVal': 0.953112, 'right': {'left': {'left': 1.06469, 'spInd': 0, 'spVal': 0.951949, 'right': {'left': 0.945255, 'spInd': 0, 'spVal': 0.950153, 'right': 1.022906}}, 'spInd': 0, 'spVal': 0.948268, 'right': {'left': 0.631862, 'spInd': 0, 'spVal': 0.936783, 'right': {'left': {'left': 1.026258, 'spInd': 0, 'spVal': 0.930173, 'right': 1.035645}, 'spInd': 0, 'spVal': 0.928097, 'right': 0.883225}}}}}, 'spInd': 0, 'spVal': 0.919384, 'right': {'left': {'left': {'left': {'left': 1.029889, 'spInd': 0, 'spVal': 0.919074, 'right': 1.123413}, 'spInd': 0, 'spVal': 0.902532, 'right': {'left': 0.861601, 'spInd': 0, 'spVal': 0.901056, 'right': {'left': 1.0559, 'spInd': 0, 'spVal': 0.900272, 'right': 0.996871}}}, 'spInd': 0, 'spVal': 0.897094, 'right': {'left': 1.240209, 'spInd': 0, 'spVal': 0.89593, 'right': {'left': 1.077275, 'spInd': 0, 'spVal': 0.884512, 'right': 1.117833}}}, 'spInd': 0, 'spVal': 0.877241, 'right': {'left': {'left': {'left': 0.797005, 'spInd': 0, 'spVal': 0.869077, 'right': 1.114825}, 'spInd': 0, 'spVal': 0.860049, 'right': 0.71749}, 'spInd': 0, 'spVal': 0.848921, 'right': 1.170959}}}, 'spInd': 0, 'spVal': 0.846455, 'right': {'left': 0.72003, 'spInd': 0, 'spVal': 0.845815, 'right': 0.952617}}, 'spInd': 0, 'spVal': 0.837522, 'right': {'left': {'left': 1.229373, 'spInd': 0, 'spVal': 0.834078, 'right': {'left': 1.01058, 'spInd': 0, 'spVal': 0.824442, 'right': {'left': 1.082153, 'spInd': 0, 'spVal': 0.822443, 'right': 1.086648}}}, 'spInd': 0, 'spVal': 0.821648, 'right': {'left': 1.280895, 'spInd': 0, 'spVal': 0.820802, 'right': 1.325907}}}, 'spInd': 0, 'spVal': 0.819823, 'right': {'left': {'left': {'left': {'left': 0.835264, 'spInd': 0, 'spVal': 0.814825, 'right': 1.095206}, 'spInd': 0, 'spVal': 0.813719, 'right': {'left': 0.706601, 'spInd': 0, 'spVal': 0.804586, 'right': {'left': 0.924033, 'spInd': 0, 'spVal': 0.795072, 'right': 0.965721}}}, 'spInd': 0, 'spVal': 0.79024, 'right': {'left': 0.533214, 'spInd': 0, 'spVal': 0.789625, 'right': 0.552614}}, 'spInd': 0, 'spVal': 0.785541, 'right': {'left': {'left': {'left': {'left': {'left': 1.165296, 'spInd': 0, 'spVal': 0.782167, 'right': {'left': {'left': 0.886049, 'spInd': 0, 'spVal': 0.78193, 'right': 1.074488}, 'spInd': 0, 'spVal': 0.774301, 'right': 0.836763}}, 'spInd': 0, 'spVal': 0.773422, 'right': {'left': {'left': 1.125943, 'spInd': 0, 'spVal': 0.773168, 'right': 1.140917}, 'spInd': 0, 'spVal': 0.772083, 'right': 1.299018}}, 'spInd': 0, 'spVal': 0.768784, 'right': {'left': {'left': {'left': {'left': 0.899705, 'spInd': 0, 'spVal': 0.768596, 'right': 0.760219}, 'spInd': 0, 'spVal': 0.761474, 'right': 1.058262}, 'spInd': 0, 'spVal': 0.750918, 'right': {'left': 0.748104, 'spInd': 0, 'spVal': 0.750078, 'right': 0.906291}}, 'spInd': 0, 'spVal': 0.742527, 'right': {'left': {'left': 1.087056, 'spInd': 0, 'spVal': 0.737189, 'right': 1.200781}, 'spInd': 0, 'spVal': 0.729234, 'right': {'left': 0.931956, 'spInd': 0, 'spVal': 0.727098, 'right': {'left': 1.000567, 'spInd': 0, 'spVal': 0.726828, 'right': 1.017112}}}}}, 'spInd': 0, 'spVal': 0.72312, 'right': 1.307248}, 'spInd': 0, 'spVal': 0.712503, 'right': {'left': {'left': 0.93349, 'spInd': 0, 'spVal': 0.712386, 'right': 0.564858}, 'spInd': 0, 'spVal': 0.703755, 'right': {'left': {'left': {'left': 1.101678, 'spInd': 0, 'spVal': 0.697777, 'right': 0.827805}, 'spInd': 0, 'spVal': 0.697718, 'right': 1.212434}, 'spInd': 0, 'spVal': 0.696648, 'right': {'left': 0.845423, 'spInd': 0, 'spVal': 0.691115, 'right': 0.834391}}}}}}, 'spInd': 0, 'spVal': 0.683921, 'right': {'left': 1.414382, 'spInd': 0, 'spVal': 0.683886, 'right': {'left': {'left': 0.999985, 'spInd': 0, 'spVal': 0.67939, 'right': 1.307217}, 'spInd': 0, 'spVal': 0.678287, 'right': {'left': {'left': 0.907727, 'spInd': 0, 'spVal': 0.673195, 'right': 0.915077}, 'spInd': 0, 'spVal': 0.66387, 'right': 1.187129}}}}, 'spInd': 0, 'spVal': 0.6632, 'right': {'left': {'left': 0.701634, 'spInd': 0, 'spVal': 0.661923, 'right': 0.76704}, 'spInd': 0, 'spVal': 0.656218, 'right': {'left': 0.958506, 'spInd': 0, 'spVal': 0.652121, 'right': 1.004346}}}, 'spInd': 0, 'spVal': 0.651376, 'right': {'left': {'left': 1.315384, 'spInd': 0, 'spVal': 0.648675, 'right': 1.287407}, 'spInd': 0, 'spVal': 0.645762, 'right': {'left': 1.026886, 'spInd': 0, 'spVal': 0.643665, 'right': 1.024241}}}, 'spInd': 0, 'spVal': 0.643601, 'right': {'left': 0.782552, 'spInd': 0, 'spVal': 0.626011, 'right': 0.840544}}, 'spInd': 0, 'spVal': 0.625791, 'right': 1.244731}, 'spInd': 0, 'spVal': 0.625336, 'right': {'left': 0.623696, 'spInd': 0, 'spVal': 0.622398, 'right': 0.76633}}, 'spInd': 0, 'spVal': 0.620599, 'right': {'left': {'left': 1.334421, 'spInd': 0, 'spVal': 0.613765, 'right': 1.621091}, 'spInd': 0, 'spVal': 0.61127, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 0.982036, 'spInd': 0, 'spVal': 0.604529, 'right': 1.212685}, 'spInd': 0, 'spVal': 0.597409, 'right': 0.97477}, 'spInd': 0, 'spVal': 0.595012, 'right': {'left': 1.213435, 'spInd': 0, 'spVal': 0.59021, 'right': 1.336661}}, 'spInd': 0, 'spVal': 0.590062, 'right': {'left': 0.705531, 'spInd': 0, 'spVal': 0.589575, 'right': {'left': {'left': {'left': 1.185812, 'spInd': 0, 'spVal': 0.578252, 'right': 0.921885}, 'spInd': 0, 'spVal': 0.576946, 'right': 1.234129}, 'spInd': 0, 'spVal': 0.575805, 'right': {'left': 0.89909, 'spInd': 0, 'spVal': 0.574573, 'right': {'left': {'left': 1.06613, 'spInd': 0, 'spVal': 0.567704, 'right': 0.969058}, 'spInd': 0, 'spVal': 0.561362, 'right': 1.070529}}}}}, 'spInd': 0, 'spVal': 0.559763, 'right': {'left': {'left': {'left': 1.253151, 'spInd': 0, 'spVal': 0.55352, 'right': 1.391273}, 'spInd': 0, 'spVal': 0.55299, 'right': 1.036158}, 'spInd': 0, 'spVal': 0.552381, 'right': 1.36963}}, 'spInd': 0, 'spVal': 0.541314, 'right': {'left': {'left': {'left': 0.893748, 'spInd': 0, 'spVal': 0.539558, 'right': 1.053846}, 'spInd': 0, 'spVal': 0.536689, 'right': {'left': {'left': 0.867284, 'spInd': 0, 'spVal': 0.530897, 'right': 0.893462}, 'spInd': 0, 'spVal': 0.529491, 'right': {'left': 1.022206, 'spInd': 0, 'spVal': 0.527505, 'right': 0.87956}}}, 'spInd': 0, 'spVal': 0.520207, 'right': {'left': 1.209557, 'spInd': 0, 'spVal': 0.520044, 'right': {'left': 0.961983, 'spInd': 0, 'spVal': 0.518735, 'right': 1.037179}}}}, 'spInd': 0, 'spVal': 0.517921, 'right': {'left': 1.493586, 'spInd': 0, 'spVal': 0.514563, 'right': {'left': 1.156648, 'spInd': 0, 'spVal': 0.50794, 'right': 1.107265}}}}}, 'spInd': 0, 'spVal': 0.48813, 'right': {'left': {'left': {'left': {'left': {'left': -0.097791, 'spInd': 0, 'spVal': 0.475976, 'right': {'left': -0.163707, 'spInd': 0, 'spVal': 0.465625, 'right': -0.15294}}, 'spInd': 0, 'spVal': 0.458121, 'right': {'left': {'left': -0.061456, 'spInd': 0, 'spVal': 0.44928, 'right': {'left': 0.069098, 'spInd': 0, 'spVal': 0.448656, 'right': {'left': 0.026974, 'spInd': 0, 'spVal': 0.438367, 'right': 0.034014}}}, 'spInd': 0, 'spVal': 0.429664, 'right': -0.188659}}, 'spInd': 0, 'spVal': 0.41023, 'right': 0.331722}, 'spInd': 0, 'spVal': 0.406649, 'right': {'left': {'left': {'left': {'left': -0.366317, 'spInd': 0, 'spVal': 0.401152, 'right': {'left': -0.12164, 'spInd': 0, 'spVal': 0.378595, 'right': -0.296094}}, 'spInd': 0, 'spVal': 0.377597, 'right': {'left': 0.088505, 'spInd': 0, 'spVal': 0.377201, 'right': -0.24355}}, 'spInd': 0, 'spVal': 0.362314, 'right': {'left': -0.556464, 'spInd': 0, 'spVal': 0.360323, 'right': -0.20483}}, 'spInd': 0, 'spVal': 0.355688, 'right': {'left': {'left': {'left': -0.119399, 'spInd': 0, 'spVal': 0.348013, 'right': 0.048939}, 'spInd': 0, 'spVal': 0.347837, 'right': {'left': {'left': -0.153405, 'spInd': 0, 'spVal': 0.346986, 'right': -0.150389}, 'spInd': 0, 'spVal': 0.344102, 'right': -0.061539}}, 'spInd': 0, 'spVal': 0.343554, 'right': -0.3717}}}, 'spInd': 0, 'spVal': 0.343479, 'right': {'left': {'left': {'left': 0.175264, 'spInd': 0, 'spVal': 0.339563, 'right': 0.206783}, 'spInd': 0, 'spVal': 0.3371, 'right': {'left': 0.026332, 'spInd': 0, 'spVal': 0.332982, 'right': 0.210084}}, 'spInd': 0, 'spVal': 0.325412, 'right': {'left': {'left': {'left': {'left': {'left': {'left': -0.219245, 'spInd': 0, 'spVal': 0.323181, 'right': {'left': 0.180811, 'spInd': 0, 'spVal': 0.314924, 'right': {'left': {'left': {'left': {'left': {'left': -0.001952, 'spInd': 0, 'spVal': 0.306964, 'right': {'left': -0.177321, 'spInd': 0, 'spVal': 0.30554, 'right': {'left': -0.115991, 'spInd': 0, 'spVal': 0.302217, 'right': -0.14865}}}, 'spInd': 0, 'spVal': 0.302001, 'right': {'left': 0.317135, 'spInd': 0, 'spVal': 0.295511, 'right': 0.002882}}, 'spInd': 0, 'spVal': 0.280738, 'right': -0.22888}, 'spInd': 0, 'spVal': 0.278661, 'right': 0.253628}, 'spInd': 0, 'spVal': 0.27394, 'right': {'left': {'left': -0.085713, 'spInd': 0, 'spVal': 0.273147, 'right': {'left': -0.455219, 'spInd': 0, 'spVal': 0.269681, 'right': -0.165971}}, 'spInd': 0, 'spVal': 0.268857, 'right': {'left': 0.073447, 'spInd': 0, 'spVal': 0.252649, 'right': {'left': -0.055613, 'spInd': 0, 'spVal': 0.250744, 'right': {'left': 0.046297, 'spInd': 0, 'spVal': 0.243909, 'right': -0.029467}}}}}}}, 'spInd': 0, 'spVal': 0.242204, 'right': 0.209359}, 'spInd': 0, 'spVal': 0.23807, 'right': {'left': {'left': -0.358459, 'spInd': 0, 'spVal': 0.233115, 'right': -0.348147}, 'spInd': 0, 'spVal': 0.210334, 'right': {'left': {'left': -0.006899, 'spInd': 0, 'spVal': 0.203693, 'right': {'left': {'left': -0.064036, 'spInd': 0, 'spVal': 0.202054, 'right': -0.087744}, 'spInd': 0, 'spVal': 0.196005, 'right': -0.048847}}, 'spInd': 0, 'spVal': 0.193641, 'right': -0.327589}}}, 'spInd': 0, 'spVal': 0.188218, 'right': {'left': {'left': {'left': {'left': {'left': 0.113685, 'spInd': 0, 'spVal': 0.18351, 'right': 0.184843}, 'spInd': 0, 'spVal': 0.180506, 'right': {'left': 0.103676, 'spInd': 0, 'spVal': 0.152324, 'right': 0.132858}}, 'spInd': 0, 'spVal': 0.148049, 'right': 0.204298}, 'spInd': 0, 'spVal': 0.146366, 'right': {'left': 0.034283, 'spInd': 0, 'spVal': 0.145809, 'right': 0.136979}}, 'spInd': 0, 'spVal': 0.1333, 'right': {'left': {'left': -0.223143, 'spInd': 0, 'spVal': 0.132543, 'right': -0.329372}, 'spInd': 0, 'spVal': 0.130962, 'right': {'left': {'left': {'left': 0.184241, 'spInd': 0, 'spVal': 0.130052, 'right': -0.026167}, 'spInd': 0, 'spVal': 0.129061, 'right': 0.305107}, 'spInd': 0, 'spVal': 0.118156, 'right': {'left': {'left': {'left': -0.077409, 'spInd': 0, 'spVal': 0.101149, 'right': {'left': 0.068834, 'spInd': 0, 'spVal': 0.099142, 'right': 0.02528}}, 'spInd': 0, 'spVal': 0.098016, 'right': -0.33276}, 'spInd': 0, 'spVal': 0.096994, 'right': {'left': 0.227167, 'spInd': 0, 'spVal': 0.091358, 'right': {'left': 0.099935, 'spInd': 0, 'spVal': 0.084248, 'right': -0.019547}}}}}}}, 'spInd': 0, 'spVal': 0.081931, 'right': {'left': {'left': -0.269756, 'spInd': 0, 'spVal': 0.074795, 'right': {'left': -0.349692, 'spInd': 0, 'spVal': 0.072243, 'right': -0.420983}}, 'spInd': 0, 'spVal': 0.071769, 'right': {'left': {'left': -0.110946, 'spInd': 0, 'spVal': 0.066172, 'right': 0.052439}, 'spInd': 0, 'spVal': 0.065615, 'right': -0.30697}}}, 'spInd': 0, 'spVal': 0.048014, 'right': {'left': {'left': 0.064496, 'spInd': 0, 'spVal': 0.036492, 'right': {'left': 0.408155, 'spInd': 0, 'spVal': 0.036098, 'right': 0.155096}}, 'spInd': 0, 'spVal': 0.014083, 'right': {'left': -0.132525, 'spInd': 0, 'spVal': 0.009849, 'right': {'left': 0.056594, 'spInd': 0, 'spVal': 0.008307, 'right': {'left': {'left': 0.069976, 'spInd': 0, 'spVal': 0.007044, 'right': 0.09415}, 'spInd': 0, 'spVal': 0.000234, 'right': 0.060903}}}}}}}}

與開始包含兩個節點的樹相比,這棵樹十分臃腫,甚至爲數據集中的每個樣本都分配了一個葉節點。

運行測試代碼:

    myDat2 = loadDataSet('ex2.txt')
    myMat2 = mat(myDat2)
    retTree2 = createTree(myMat2, ops=(0, 1))

    print('retTree2  ', retTree2)
    print('retTree2_1  ', createTree(myMat2, ops=(10000, 4)))

結果:

retTree2   {'left': {'left': {'left': {'left': {'left': 86.399637, 'spInd': 0, 'spVal': 0.968621, 'right': 98.648346}, 'spInd': 0, 'spVal': 0.965969, 'right': {'left': {'left': {'left': 112.386764, 'spInd': 0, 'spVal': 0.960398, 'right': 123.559747}, 'spInd': 0, 'spVal': 0.958512, 'right': 135.837013}, 'spInd': 0, 'spVal': 0.956951, 'right': {'left': {'left': 82.016541, 'spInd': 0, 'spVal': 0.954711, 'right': 100.935789}, 'spInd': 0, 'spVal': 0.953902, 'right': 130.92648}}}, 'spInd': 0, 'spVal': 0.952833, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 100.649591, 'spInd': 0, 'spVal': 0.952377, 'right': 73.520802}, 'spInd': 0, 'spVal': 0.949198, 'right': 105.752508}, 'spInd': 0, 'spVal': 0.948822, 'right': 69.318649}, 'spInd': 0, 'spVal': 0.944221, 'right': {'left': {'left': {'left': 100.120253, 'spInd': 0, 'spVal': 0.937766, 'right': 119.949824}, 'spInd': 0, 'spVal': 0.936524, 'right': {'left': 65.548418, 'spInd': 0, 'spVal': 0.934853, 'right': {'left': 115.753994, 'spInd': 0, 'spVal': 0.925782, 'right': {'left': {'left': {'left': 92.074619, 'spInd': 0, 'spVal': 0.915263, 'right': 96.71761}, 'spInd': 0, 'spVal': 0.912161, 'right': 85.005351}, 'spInd': 0, 'spVal': 0.910975, 'right': {'left': {'left': 106.814667, 'spInd': 0, 'spVal': 0.908629, 'right': 118.513475}, 'spInd': 0, 'spVal': 0.901444, 'right': {'left': 87.300625, 'spInd': 0, 'spVal': 0.901421, 'right': {'left': {'left': 100.133819, 'spInd': 0, 'spVal': 0.900699, 'right': {'left': 109.188248, 'spInd': 0, 'spVal': 0.896683, 'right': 107.00162}}, 'spInd': 0, 'spVal': 0.892999, 'right': {'left': 82.436686, 'spInd': 0, 'spVal': 0.888426, 'right': {'left': {'left': {'left': 94.896354, 'spInd': 0, 'spVal': 0.885676, 'right': 108.045948}, 'spInd': 0, 'spVal': 0.883615, 'right': {'left': 95.348184, 'spInd': 0, 'spVal': 0.872883, 'right': 95.887712}}, 'spInd': 0, 'spVal': 0.872199, 'right': {'left': 111.552716, 'spInd': 0, 'spVal': 0.866451, 'right': {'left': 94.402102, 'spInd': 0, 'spVal': 0.856421, 'right': 107.166848}}}}}}}}}}}, 'spInd': 0, 'spVal': 0.85497, 'right': {'left': {'left': 89.20993, 'spInd': 0, 'spVal': 0.847219, 'right': 76.240984}, 'spInd': 0, 'spVal': 0.84294, 'right': 95.893131}}}, 'spInd': 0, 'spVal': 0.841625, 'right': 60.552308}, 'spInd': 0, 'spVal': 0.841547, 'right': {'left': 115.669032, 'spInd': 0, 'spVal': 0.838587, 'right': 134.089674}}, 'spInd': 0, 'spVal': 0.833026, 'right': {'left': 76.723835, 'spInd': 0, 'spVal': 0.823848, 'right': {'left': 59.342323, 'spInd': 0, 'spVal': 0.819722, 'right': 70.054508}}}, 'spInd': 0, 'spVal': 0.815215, 'right': {'left': 118.319942, 'spInd': 0, 'spVal': 0.811602, 'right': {'left': 99.841379, 'spInd': 0, 'spVal': 0.811363, 'right': 112.981216}}}, 'spInd': 0, 'spVal': 0.806158, 'right': {'left': 62.877698, 'spInd': 0, 'spVal': 0.799873, 'right': {'left': 91.368473, 'spInd': 0, 'spVal': 0.798198, 'right': 76.853728}}}, 'spInd': 0, 'spVal': 0.790312, 'right': {'left': {'left': 110.15973, 'spInd': 0, 'spVal': 0.787755, 'right': 118.642009}, 'spInd': 0, 'spVal': 0.786865, 'right': {'left': 100.598825, 'spInd': 0, 'spVal': 0.785574, 'right': {'left': 107.024467, 'spInd': 0, 'spVal': 0.777582, 'right': 100.838446}}}}, 'spInd': 0, 'spVal': 0.769043, 'right': 64.041941}, 'spInd': 0, 'spVal': 0.763328, 'right': 115.199195}, 'spInd': 0, 'spVal': 0.759504, 'right': {'left': {'left': 81.106762, 'spInd': 0, 'spVal': 0.757527, 'right': 63.549854}, 'spInd': 0, 'spVal': 0.740859, 'right': {'left': 93.773929, 'spInd': 0, 'spVal': 0.731636, 'right': 73.912028}}}}, 'spInd': 0, 'spVal': 0.729397, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 110.90283, 'spInd': 0, 'spVal': 0.716211, 'right': {'left': 103.345308, 'spInd': 0, 'spVal': 0.710234, 'right': 108.553919}}, 'spInd': 0, 'spVal': 0.70889, 'right': 135.416767}, 'spInd': 0, 'spVal': 0.706961, 'right': {'left': {'left': {'left': {'left': 106.180427, 'spInd': 0, 'spVal': 0.70639, 'right': 105.062147}, 'spInd': 0, 'spVal': 0.699873, 'right': 115.586605}, 'spInd': 0, 'spVal': 0.69892, 'right': 92.470636}, 'spInd': 0, 'spVal': 0.698472, 'right': {'left': 120.521925, 'spInd': 0, 'spVal': 0.689099, 'right': {'left': {'left': {'left': 112.378209, 'spInd': 0, 'spVal': 0.680486, 'right': 110.367074}, 'spInd': 0, 'spVal': 0.667851, 'right': 92.449664}, 'spInd': 0, 'spVal': 0.666452, 'right': {'left': 120.014736, 'spInd': 0, 'spVal': 0.665652, 'right': 105.547997}}}}}, 'spInd': 0, 'spVal': 0.665329, 'right': {'left': 121.980607, 'spInd': 0, 'spVal': 0.661073, 'right': {'left': 115.687524, 'spInd': 0, 'spVal': 0.652462, 'right': 112.715799}}}, 'spInd': 0, 'spVal': 0.642707, 'right': 82.500766}, 'spInd': 0, 'spVal': 0.642373, 'right': 140.613941}, 'spInd': 0, 'spVal': 0.640515, 'right': {'left': {'left': {'left': {'left': 82.713621, 'spInd': 0, 'spVal': 0.637999, 'right': {'left': 91.656617, 'spInd': 0, 'spVal': 0.632691, 'right': 93.645293}}, 'spInd': 0, 'spVal': 0.628061, 'right': {'left': 117.628346, 'spInd': 0, 'spVal': 0.624827, 'right': 105.970743}}, 'spInd': 0, 'spVal': 0.623909, 'right': {'left': 87.181863, 'spInd': 0, 'spVal': 0.618868, 'right': 76.917665}}, 'spInd': 0, 'spVal': 0.613004, 'right': {'left': 168.180746, 'spInd': 0, 'spVal': 0.606417, 'right': {'left': {'left': {'left': {'left': {'left': {'left': 93.521396, 'spInd': 0, 'spVal': 0.599142, 'right': {'left': 130.378529, 'spInd': 0, 'spVal': 0.589806, 'right': {'left': {'left': 98.674874, 'spInd': 0, 'spVal': 0.585413, 'right': 125.295113}, 'spInd': 0, 'spVal': 0.582311, 'right': {'left': 82.589328, 'spInd': 0, 'spVal': 0.571214, 'right': {'left': 114.872056, 'spInd': 0, 'spVal': 0.569327, 'right': 108.435392}}}}}, 'spInd': 0, 'spVal': 0.560301, 'right': 82.903945}, 'spInd': 0, 'spVal': 0.553797, 'right': {'left': 120.857321, 'spInd': 0, 'spVal': 0.549814, 'right': 137.267576}}, 'spInd': 0, 'spVal': 0.548539, 'right': {'left': 83.114502, 'spInd': 0, 'spVal': 0.546601, 'right': {'left': {'left': 96.319043, 'spInd': 0, 'spVal': 0.543843, 'right': 98.36201}, 'spInd': 0, 'spVal': 0.537834, 'right': 90.995536}}}, 'spInd': 0, 'spVal': 0.533511, 'right': {'left': {'left': 129.766743, 'spInd': 0, 'spVal': 0.531944, 'right': 124.795495}, 'spInd': 0, 'spVal': 0.51915, 'right': 116.176162}}, 'spInd': 0, 'spVal': 0.513332, 'right': {'left': 101.075609, 'spInd': 0, 'spVal': 0.508548, 'right': {'left': 93.292829, 'spInd': 0, 'spVal': 0.508542, 'right': 96.403373}}}}}}}, 'spInd': 0, 'spVal': 0.499171, 'right': {'left': {'left': {'left': {'left': {'left': {'left': 11.924204, 'spInd': 0, 'spVal': 0.487537, 'right': 5.149336}, 'spInd': 0, 'spVal': 0.487381, 'right': 27.729263}, 'spInd': 0, 'spVal': 0.483803, 'right': 5.224234}, 'spInd': 0, 'spVal': 0.467383, 'right': {'left': -9.712925, 'spInd': 0, 'spVal': 0.46568, 'right': -23.777531}}, 'spInd': 0, 'spVal': 0.465561, 'right': {'left': 30.051931, 'spInd': 0, 'spVal': 0.463241, 'right': 17.171057}}, 'spInd': 0, 'spVal': 0.457563, 'right': {'left': -34.044555, 'spInd': 0, 'spVal': 0.455761, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 9.841938, 'spInd': 0, 'spVal': 0.454375, 'right': 3.043912}, 'spInd': 0, 'spVal': 0.454312, 'right': {'left': {'left': -20.360067, 'spInd': 0, 'spVal': 0.451087, 'right': -28.724685}, 'spInd': 0, 'spVal': 0.446196, 'right': -5.108172}}, 'spInd': 0, 'spVal': 0.437652, 'right': {'left': {'left': {'left': {'left': 19.745224, 'spInd': 0, 'spVal': 0.428582, 'right': 15.224266}, 'spInd': 0, 'spVal': 0.426711, 'right': -21.594268}, 'spInd': 0, 'spVal': 0.418943, 'right': 44.161493}, 'spInd': 0, 'spVal': 0.412516, 'right': {'left': -26.419289, 'spInd': 0, 'spVal': 0.403228, 'right': {'left': -1.729244, 'spInd': 0, 'spVal': 0.391609, 'right': 3.001104}}}}, 'spInd': 0, 'spVal': 0.388789, 'right': {'left': 21.578007, 'spInd': 0, 'spVal': 0.385021, 'right': 24.816941}}, 'spInd': 0, 'spVal': 0.382037, 'right': {'left': {'left': {'left': -29.007783, 'spInd': 0, 'spVal': 0.378965, 'right': {'left': {'left': 13.583555, 'spInd': 0, 'spVal': 0.377383, 'right': 5.241196}, 'spInd': 0, 'spVal': 0.373501, 'right': -8.228297}}, 'spInd': 0, 'spVal': 0.370042, 'right': {'left': -32.124495, 'spInd': 0, 'spVal': 0.35679, 'right': {'left': {'left': -19.526539, 'spInd': 0, 'spVal': 0.351478, 'right': -0.461116}, 'spInd': 0, 'spVal': 0.350725, 'right': {'left': -40.086564, 'spInd': 0, 'spVal': 0.350065, 'right': {'left': -1.319852, 'spInd': 0, 'spVal': 0.342761, 'right': {'left': -31.584855, 'spInd': 0, 'spVal': 0.342155, 'right': {'left': -16.930416, 'spInd': 0, 'spVal': 0.3417, 'right': -23.547711}}}}}}}, 'spInd': 0, 'spVal': 0.335182, 'right': {'left': {'left': {'left': {'left': 2.768225, 'spInd': 0, 'spVal': 0.3349, 'right': 18.97665}, 'spInd': 0, 'spVal': 0.331364, 'right': -1.290825}, 'spInd': 0, 'spVal': 0.32889, 'right': 39.783113}, 'spInd': 0, 'spVal': 0.324274, 'right': {'left': {'left': {'left': -13.189243, 'spInd': 0, 'spVal': 0.318309, 'right': -27.605424}, 'spInd': 0, 'spVal': 0.310956, 'right': -49.939516}, 'spInd': 0, 'spVal': 0.309133, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 8.814725, 'spInd': 0, 'spVal': 0.300318, 'right': {'left': -18.051318, 'spInd': 0, 'spVal': 0.297107, 'right': {'left': -1.798377, 'spInd': 0, 'spVal': 0.295993, 'right': {'left': -14.988279, 'spInd': 0, 'spVal': 0.290749, 'right': -14.391613}}}}, 'spInd': 0, 'spVal': 0.284794, 'right': {'left': 35.623746, 'spInd': 0, 'spVal': 0.273863, 'right': {'left': -9.457556, 'spInd': 0, 'spVal': 0.264926, 'right': {'left': 5.280579, 'spInd': 0, 'spVal': 0.264639, 'right': 2.557923}}}}, 'spInd': 0, 'spVal': 0.25807, 'right': {'left': {'left': {'left': -20.425137, 'spInd': 0, 'spVal': 0.232802, 'right': 1.222318}, 'spInd': 0, 'spVal': 0.228751, 'right': -30.812912}, 'spInd': 0, 'spVal': 0.228628, 'right': -2.266273}}, 'spInd': 0, 'spVal': 0.228473, 'right': {'left': {'left': 19.425158, 'spInd': 0, 'spVal': 0.2232, 'right': 15.501642}, 'spInd': 0, 'spVal': 0.222271, 'right': {'left': -9.255852, 'spInd': 0, 'spVal': 0.218321, 'right': {'left': 1.410768, 'spInd': 0, 'spVal': 0.217214, 'right': -3.958752}}}}, 'spInd': 0, 'spVal': 0.211633, 'right': {'left': {'left': {'left': -8.332207, 'spInd': 0, 'spVal': 0.206207, 'right': -12.619036}, 'spInd': 0, 'spVal': 0.203993, 'right': -22.379119}, 'spInd': 0, 'spVal': 0.202161, 'right': {'left': -1.983889, 'spInd': 0, 'spVal': 0.199903, 'right': -3.372472}}}, 'spInd': 0, 'spVal': 0.193282, 'right': {'left': 18.208423, 'spInd': 0, 'spVal': 0.176523, 'right': 0.946348}}, 'spInd': 0, 'spVal': 0.166765, 'right': {'left': {'left': {'left': -14.740059, 'spInd': 0, 'spVal': 0.166431, 'right': -6.512506}, 'spInd': 0, 'spVal': 0.164134, 'right': -27.405211}, 'spInd': 0, 'spVal': 0.156273, 'right': 0.225886}}, 'spInd': 0, 'spVal': 0.156067, 'right': {'left': 7.557349, 'spInd': 0, 'spVal': 0.13988, 'right': 7.336784}}, 'spInd': 0, 'spVal': 0.138619, 'right': -29.087463}, 'spInd': 0, 'spVal': 0.131833, 'right': 22.478291}}}}}, 'spInd': 0, 'spVal': 0.130626, 'right': -39.524461}, 'spInd': 0, 'spVal': 0.126833, 'right': {'left': 22.891675, 'spInd': 0, 'spVal': 0.124723, 'right': {'left': {'left': {'left': -1.402796, 'spInd': 0, 'spVal': 0.11515, 'right': 13.795828}, 'spInd': 0, 'spVal': 0.108801, 'right': {'left': -16.106164, 'spInd': 0, 'spVal': 0.10796, 'right': {'left': -1.293195, 'spInd': 0, 'spVal': 0.085873, 'right': -10.137104}}}, 'spInd': 0, 'spVal': 0.085111, 'right': {'left': 37.820659, 'spInd': 0, 'spVal': 0.084661, 'right': {'left': -24.132226, 'spInd': 0, 'spVal': 0.080061, 'right': {'left': {'left': 2.229873, 'spInd': 0, 'spVal': 0.079632, 'right': 29.420068}, 'spInd': 0, 'spVal': 0.068373, 'right': {'left': -15.160836, 'spInd': 0, 'spVal': 0.061219, 'right': {'left': {'left': {'left': 6.695567, 'spInd': 0, 'spVal': 0.055862, 'right': -3.131497}, 'spInd': 0, 'spVal': 0.053764, 'right': -13.731698}, 'spInd': 0, 'spVal': 0.044737, 'right': {'left': {'left': 3.855393, 'spInd': 0, 'spVal': 0.039914, 'right': 11.220099}, 'spInd': 0, 'spVal': 0.028546, 'right': {'left': -8.377094, 'spInd': 0, 'spVal': 0.000256, 'right': 9.668106}}}}}}}}}}}}}
retTree2_1   {'left': 101.35815937735848, 'spInd': 0, 'spVal': 0.499171, 'right': -2.637719329787234}

當tolS取0時,樹的葉子結點過分多!當取值爲10000時,只有兩個葉子結點,這是因爲停止條件tols對誤差的數量級非常敏感

預剪枝缺點:通過不斷修改停止條件來得到合理結果不是很好的辦法。

後剪枝

後剪枝方法需要將數據集分爲測試集和訓練集。首先指定參數,是的構建出的樹足夠大,足夠複雜,便於剪枝。接下來從上而下找到葉節點,用測試機來判斷將這些葉節點合併是否能夠降低測試誤差,如果是的話就合併。
函數prune()僞代碼:

基於已有的樹切分測試數據:
	如果存在任一子集是一棵樹,則在該子集遞歸剪枝過程
	計算將當前兩個葉節點合併後的誤差
	計算不合並的誤差
	如果合併會降低誤差的話,就將葉節點合併

迴歸樹剪枝函數python代碼:

def isTree(obj):
    '''
    測試輸入變量是否是一棵樹,判斷當前處理的節點是否是葉節點
    :param obj:
    :return: 布爾類型
    '''
    return type(obj).__name__ == 'dict'


def getMean(tree):
    '''
    從上往下遍歷樹直至找到葉節點爲止,遞歸
    :param tree:
    :return: 找到兩個葉節點則計算他們的平均值(塌陷處理)
    '''
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

def prune(tree, testData):
    '''

    :param tree:待剪枝的樹
    :param testData: 剪枝所需的測試數據
    :return:
    '''
    # 沒有測試數據則對樹進行塌陷處理
    if shape(testData)[0] == 0:
        return getMean(tree)

    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1]-tree['left'], 2)) + sum(power(rSet[:, -1]-tree['right'], 2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:, -1]-treeMean, 2))
        if errorMerge < errorNoMerge:
            print('merging!')
            return treeMean
        else:
            return tree
    else:
        return tree

測試代碼:

    myDat2 = loadDataSet('ex2.txt')
    myMat2 = mat(myDat2)
    retTree2 = createTree(myMat2, ops=(0, 1))
    # 後剪枝:利用測試集來對樹進行剪枝,由於不需要用戶指定參數,是一種更理想化的剪枝方法。
    myDatTest = loadDataSet('ex2test.txt')
    myMat2Test = mat(myDatTest)
    print('後剪枝:', prune(retTree2, myMat2Test))

結果:

E:\Anaconda2\envs\py3\python.exe D:/pycharmworkshop/ML_in_action/regression_8_9/9_CART_Regression/regTrees.py
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
merging!
後剪枝: {'left': {'left': {'left': {'left': 92.5239915, 'spVal': 0.965969, 'spInd': 0, 'right': {'left': {'left': {'left': 112.386764, 'spVal': 0.960398, 'spInd': 0, 'right': 123.559747}, 'spVal': 0.958512, 'spInd': 0, 'right': 135.837013}, 'spVal': 0.956951, 'spInd': 0, 'right': 111.2013225}}, 'spVal': 0.952833, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225, 'spVal': 0.948822, 'spInd': 0, 'right': 69.318649}, 'spVal': 0.944221, 'spInd': 0, 'right': {'left': {'left': 110.03503850000001, 'spVal': 0.936524, 'spInd': 0, 'right': {'left': 65.548418, 'spVal': 0.934853, 'spInd': 0, 'right': {'left': 115.753994, 'spVal': 0.925782, 'spInd': 0, 'right': {'left': {'left': 94.3961145, 'spVal': 0.912161, 'spInd': 0, 'right': 85.005351}, 'spVal': 0.910975, 'spInd': 0, 'right': {'left': {'left': 106.814667, 'spVal': 0.908629, 'spInd': 0, 'right': 118.513475}, 'spVal': 0.901444, 'spInd': 0, 'right': {'left': 87.300625, 'spVal': 0.901421, 'spInd': 0, 'right': {'left': {'left': 100.133819, 'spVal': 0.900699, 'spInd': 0, 'right': 108.094934}, 'spVal': 0.892999, 'spInd': 0, 'right': {'left': 82.436686, 'spVal': 0.888426, 'spInd': 0, 'right': {'left': 98.54454949999999, 'spVal': 0.872199, 'spInd': 0, 'right': 106.16859550000001}}}}}}}}}, 'spVal': 0.85497, 'spInd': 0, 'right': {'left': {'left': 89.20993, 'spVal': 0.847219, 'spInd': 0, 'right': 76.240984}, 'spVal': 0.84294, 'spInd': 0, 'right': 95.893131}}}, 'spVal': 0.841625, 'spInd': 0, 'right': 60.552308}, 'spVal': 0.841547, 'spInd': 0, 'right': 124.87935300000001}, 'spVal': 0.833026, 'spInd': 0, 'right': {'left': 76.723835, 'spVal': 0.823848, 'spInd': 0, 'right': {'left': 59.342323, 'spVal': 0.819722, 'spInd': 0, 'right': 70.054508}}}, 'spVal': 0.815215, 'spInd': 0, 'right': {'left': 118.319942, 'spVal': 0.811602, 'spInd': 0, 'right': {'left': 99.841379, 'spVal': 0.811363, 'spInd': 0, 'right': 112.981216}}}, 'spVal': 0.806158, 'spInd': 0, 'right': 73.49439925}, 'spVal': 0.790312, 'spInd': 0, 'right': {'left': 114.4008695, 'spVal': 0.786865, 'spInd': 0, 'right': 102.26514075}}, 'spVal': 0.769043, 'spInd': 0, 'right': 64.041941}, 'spVal': 0.763328, 'spInd': 0, 'right': 115.199195}, 'spVal': 0.759504, 'spInd': 0, 'right': 78.08564325}}, 'spVal': 0.729397, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 110.90283, 'spVal': 0.716211, 'spInd': 0, 'right': {'left': 103.345308, 'spVal': 0.710234, 'spInd': 0, 'right': 108.553919}}, 'spVal': 0.70889, 'spInd': 0, 'right': 135.416767}, 'spVal': 0.706961, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': 106.180427, 'spVal': 0.70639, 'spInd': 0, 'right': 105.062147}, 'spVal': 0.699873, 'spInd': 0, 'right': 115.586605}, 'spVal': 0.69892, 'spInd': 0, 'right': 92.470636}, 'spVal': 0.698472, 'spInd': 0, 'right': {'left': 120.521925, 'spVal': 0.689099, 'spInd': 0, 'right': {'left': 101.91115275, 'spVal': 0.666452, 'spInd': 0, 'right': 112.78136649999999}}}}, 'spVal': 0.665329, 'spInd': 0, 'right': {'left': 121.980607, 'spVal': 0.661073, 'spInd': 0, 'right': {'left': 115.687524, 'spVal': 0.652462, 'spInd': 0, 'right': 112.715799}}}, 'spVal': 0.642707, 'spInd': 0, 'right': 82.500766}, 'spVal': 0.642373, 'spInd': 0, 'right': 140.613941}, 'spVal': 0.640515, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': 82.713621, 'spVal': 0.637999, 'spInd': 0, 'right': {'left': 91.656617, 'spVal': 0.632691, 'spInd': 0, 'right': 93.645293}}, 'spVal': 0.628061, 'spInd': 0, 'right': {'left': 117.628346, 'spVal': 0.624827, 'spInd': 0, 'right': 105.970743}}, 'spVal': 0.623909, 'spInd': 0, 'right': 82.04976400000001}, 'spVal': 0.613004, 'spInd': 0, 'right': {'left': 168.180746, 'spVal': 0.606417, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': {'left': {'left': 93.521396, 'spVal': 0.599142, 'spInd': 0, 'right': {'left': 130.378529, 'spVal': 0.589806, 'spInd': 0, 'right': {'left': 111.9849935, 'spVal': 0.582311, 'spInd': 0, 'right': {'left': 82.589328, 'spVal': 0.571214, 'spInd': 0, 'right': {'left': 114.872056, 'spVal': 0.569327, 'spInd': 0, 'right': 108.435392}}}}}, 'spVal': 0.560301, 'spInd': 0, 'right': 82.903945}, 'spVal': 0.553797, 'spInd': 0, 'right': 129.0624485}, 'spVal': 0.548539, 'spInd': 0, 'right': {'left': 83.114502, 'spVal': 0.546601, 'spInd': 0, 'right': {'left': 97.3405265, 'spVal': 0.537834, 'spInd': 0, 'right': 90.995536}}}, 'spVal': 0.533511, 'spInd': 0, 'right': {'left': {'left': 129.766743, 'spVal': 0.531944, 'spInd': 0, 'right': 124.795495}, 'spVal': 0.51915, 'spInd': 0, 'right': 116.176162}}, 'spVal': 0.513332, 'spInd': 0, 'right': {'left': 101.075609, 'spVal': 0.508548, 'spInd': 0, 'right': {'left': 93.292829, 'spVal': 0.508542, 'spInd': 0, 'right': 96.403373}}}}}}}, 'spVal': 0.499171, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': {'left': 8.53677, 'spVal': 0.487381, 'spInd': 0, 'right': 27.729263}, 'spVal': 0.483803, 'spInd': 0, 'right': 5.224234}, 'spVal': 0.467383, 'spInd': 0, 'right': {'left': -9.712925, 'spVal': 0.46568, 'spInd': 0, 'right': -23.777531}}, 'spVal': 0.465561, 'spInd': 0, 'right': {'left': 30.051931, 'spVal': 0.463241, 'spInd': 0, 'right': 17.171057}}, 'spVal': 0.457563, 'spInd': 0, 'right': {'left': -34.044555, 'spVal': 0.455761, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': {'left': -4.1911745, 'spVal': 0.437652, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': 19.745224, 'spVal': 0.428582, 'spInd': 0, 'right': 15.224266}, 'spVal': 0.426711, 'spInd': 0, 'right': -21.594268}, 'spVal': 0.418943, 'spInd': 0, 'right': 44.161493}, 'spVal': 0.412516, 'spInd': 0, 'right': {'left': -26.419289, 'spVal': 0.403228, 'spInd': 0, 'right': 0.6359300000000001}}}, 'spVal': 0.388789, 'spInd': 0, 'right': 23.197474}, 'spVal': 0.382037, 'spInd': 0, 'right': {'left': {'left': {'left': -29.007783, 'spVal': 0.378965, 'spInd': 0, 'right': {'left': {'left': 13.583555, 'spVal': 0.377383, 'spInd': 0, 'right': 5.241196}, 'spVal': 0.373501, 'spInd': 0, 'right': -8.228297}}, 'spVal': 0.370042, 'spInd': 0, 'right': {'left': -32.124495, 'spVal': 0.35679, 'spInd': 0, 'right': {'left': -9.9938275, 'spVal': 0.350725, 'spInd': 0, 'right': -26.851234812500003}}}, 'spVal': 0.335182, 'spInd': 0, 'right': {'left': 22.286959625, 'spVal': 0.324274, 'spInd': 0, 'right': {'left': {'left': -20.3973335, 'spVal': 0.310956, 'spInd': 0, 'right': -49.939516}, 'spVal': 0.309133, 'spInd': 0, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 8.814725, 'spVal': 0.300318, 'spInd': 0, 'right': {'left': -18.051318, 'spVal': 0.297107, 'spInd': 0, 'right': {'left': -1.798377, 'spVal': 0.295993, 'spInd': 0, 'right': {'left': -14.988279, 'spVal': 0.290749, 'spInd': 0, 'right': -14.391613}}}}, 'spVal': 0.284794, 'spInd': 0, 'right': {'left': 35.623746, 'spVal': 0.273863, 'spInd': 0, 'right': {'left': -9.457556, 'spVal': 0.264926, 'spInd': 0, 'right': {'left': 5.280579, 'spVal': 0.264639, 'spInd': 0, 'right': 2.557923}}}}, 'spVal': 0.25807, 'spInd': 0, 'right': {'left': {'left': -9.601409499999999, 'spVal': 0.228751, 'spInd': 0, 'right': -30.812912}, 'spVal': 0.228628, 'spInd': 0, 'right': -2.266273}}, 'spVal': 0.228473, 'spInd': 0, 'right': 6.099239}, 'spVal': 0.211633, 'spInd': 0, 'right': {'left': -16.42737025, 'spVal': 0.202161, 'spInd': 0, 'right': -2.6781805}}, 'spVal': 0.193282, 'spInd': 0, 'right': 9.5773855}, 'spVal': 0.166765, 'spInd': 0, 'right': {'left': {'left': {'left': -14.740059, 'spVal': 0.166431, 'spInd': 0, 'right': -6.512506}, 'spVal': 0.164134, 'spInd': 0, 'right': -27.405211}, 'spVal': 0.156273, 'spInd': 0, 'right': 0.225886}}, 'spVal': 0.156067, 'spInd': 0, 'right': {'left': 7.557349, 'spVal': 0.13988, 'spInd': 0, 'right': 7.336784}}, 'spVal': 0.138619, 'spInd': 0, 'right': -29.087463}, 'spVal': 0.131833, 'spInd': 0, 'right': 22.478291}}}}}, 'spVal': 0.130626, 'spInd': 0, 'right': -39.524461}, 'spVal': 0.126833, 'spInd': 0, 'right': {'left': 22.891675, 'spVal': 0.124723, 'spInd': 0, 'right': {'left': {'left': 6.196516, 'spVal': 0.108801, 'spInd': 0, 'right': {'left': -16.106164, 'spVal': 0.10796, 'spInd': 0, 'right': {'left': -1.293195, 'spVal': 0.085873, 'spInd': 0, 'right': -10.137104}}}, 'spVal': 0.085111, 'spInd': 0, 'right': {'left': 37.820659, 'spVal': 0.084661, 'spInd': 0, 'right': {'left': -24.132226, 'spVal': 0.080061, 'spInd': 0, 'right': {'left': 15.824970500000001, 'spVal': 0.068373, 'spInd': 0, 'right': {'left': -15.160836, 'spVal': 0.061219, 'spInd': 0, 'right': {'left': {'left': {'left': 6.695567, 'spVal': 0.055862, 'spInd': 0, 'right': -3.131497}, 'spVal': 0.053764, 'spInd': 0, 'right': -13.731698}, 'spVal': 0.044737, 'spInd': 0, 'right': 4.091626}}}}}}}}}}}

Process finished with exit code 0

可以看到大量的葉節點被剪枝掉了,但沒有像預期那樣剪成兩部分,後剪枝效果不一定比預剪枝效果有效,一般爲了尋求最佳模型可以同時使用兩種剪枝技術。

scikit-learn中有兩類決策樹

均採用優化的CART決策樹算法
參數詳解如下:

# 分類決策樹
from sklearn.tree import DecisionTreeClassifier
DecisionTreeClassifier(criterion="gini",
                 splitter="best",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=1,
                 min_weight_fraction_leaf=0.,
                 max_features=None,
                 random_state=None,
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 min_impurity_split=None,
                 class_weight=None,
                 presort=False)
參數含義:
1.criterion:string, optional (default="gini")
            (1).criterion='gini',分裂節點時評價準則是Gini指數。
            (2).criterion='entropy',分裂節點時的評價指標是信息增益。相當於ID3算法
2.max_depth:int or None, optional (default=None)。指定樹的最大深度。
            如果爲None,表示樹的深度不限。直到所有的葉子節點都是純淨的,即葉子節點
            中所有的樣本點都屬於同一個類別。或者每個葉子節點包含的樣本數小於min_samples_split。
3.splitter:string, optional (default="best")。指定分裂節點時的策略。
           (1).splitter='best',表示選擇最優的分裂策略。
           (2).splitter='random',表示選擇最好的隨機切分策略。
4.min_samples_split:int, float, optional (default=2)。表示分裂一個內部節點需要的最少樣本數。
           (1).如果爲整數,則min_samples_split就是最少樣本數。
           (2).如果爲浮點數(0到1之間),則每次分裂最少樣本數爲ceil(min_samples_split * n_samples)
5.min_samples_leaf: int, float, optional (default=1)。指定每個葉子節點需要的最少樣本數。
           (1).如果爲整數,則min_samples_split就是最少樣本數。
           (2).如果爲浮點數(0到1之間),則每個葉子節點最少樣本數爲ceil(min_samples_leaf * n_samples)
6.min_weight_fraction_leaf:float, optional (default=0.)
           指定葉子節點中樣本的最小權重。
7.max_features:int, float, string or None, optional (default=None).
           搜尋最佳劃分的時候考慮的特徵數量。
           (1).如果爲整數,每次分裂只考慮max_features個特徵。
           (2).如果爲浮點數(0到1之間),每次切分只考慮int(max_features * n_features)個特徵。
           (3).如果爲'auto'或者'sqrt',則每次切分只考慮sqrt(n_features)個特徵
           (4).如果爲'log2',則每次切分只考慮log2(n_features)個特徵。
           (5).如果爲None,則每次切分考慮n_features個特徵。
           (6).如果已經考慮了max_features個特徵,但還是沒有找到一個有效的切分,那麼還會繼續尋找
           下一個特徵,直到找到一個有效的切分爲止。
8.random_state:int, RandomState instance or None, optional (default=None)
           (1).如果爲整數,則它指定了隨機數生成器的種子。
           (2).如果爲RandomState實例,則指定了隨機數生成器。
           (3).如果爲None,則使用默認的隨機數生成器。
9.max_leaf_nodes: int or None, optional (default=None)。指定了葉子節點的最大數量。
           (1).如果爲None,葉子節點數量不限。
           (2).如果爲整數,則max_depth被忽略。
10.min_impurity_decrease:float, optional (default=0.)
         如果節點的分裂導致不純度的減少(分裂後樣本比分裂前更加純淨)大於或等於min_impurity_decrease,則分裂該節點。
         加權不純度的減少量計算公式爲:
         min_impurity_decrease=N_t / N * (impurity - N_t_R / N_t * right_impurity
                            - N_t_L / N_t * left_impurity)
         其中N是樣本的總數,N_t是當前節點的樣本數,N_t_L是分裂後左子節點的樣本數,
         N_t_R是分裂後右子節點的樣本數。impurity指當前節點的基尼指數,right_impurity指
         分裂後右子節點的基尼指數。left_impurity指分裂後左子節點的基尼指數。
11.min_impurity_split:float
         樹生長過程中早停止的閾值。如果當前節點的不純度高於閾值,節點將分裂,否則它是葉子節點。
         這個參數已經被棄用。用min_impurity_decrease代替了min_impurity_split。
12.class_weight:dict, list of dicts, "balanced" or None, default=None
         類別權重的形式爲{class_label: weight}
         (1).如果沒有給出每個類別的權重,則每個類別的權重都爲1。
         (2).如果class_weight='balanced',則分類的權重與樣本中每個類別出現的頻率成反比。
         計算公式爲:n_samples / (n_classes * np.bincount(y))
         (3).如果sample_weight提供了樣本權重(由fit方法提供),則這些權重都會乘以sample_weight。
13.presort:bool, optional (default=False)
        指定是否需要提前排序數據從而加速訓練中尋找最優切分的過程。設置爲True時,對於大數據集
        會減慢總體的訓練過程;但是對於一個小數據集或者設定了最大深度的情況下,會加速訓練過程。
屬性:
1.classes_:array of shape = [n_classes] or a list of such arrays
        類別的標籤值。
2.feature_importances_ : array of shape = [n_features]
        特徵重要性。越高,特徵越重要。
        特徵的重要性爲該特徵導致的評價準則的(標準化的)總減少量。它也被稱爲基尼的重要性
3.max_features_ : int
        max_features的推斷值。
4.n_classes_ : int or list
        類別的數量
5.n_features_ : int
        執行fit後,特徵的數量
6.n_outputs_ : int
        執行fit後,輸出的數量
7.tree_ : Tree object
        樹對象,即底層的決策樹。
方法:
1.fit(X,y):訓練模型。
2.predict(X):預測
3.predict_log_poba(X):預測X爲各個類別的概率對數值。
4.predict_proba(X):預測X爲各個類別的概率值。


# 迴歸決策樹
from sklearn.tree import DecisionTreeRegressor
DecisionTreeRegressor(criterion="mse",
                         splitter="best",
                         max_depth=None,
                         min_samples_split=2,
                         min_samples_leaf=1,
                         min_weight_fraction_leaf=0.,
                         max_features=None,
                         random_state=None,
                         max_leaf_nodes=None,
                         min_impurity_decrease=0.,
                         min_impurity_split=None,
                         presort=False)
參數含義:
1.criterion:string, optional (default="mse")
            它指定了切分質量的評價準則。默認爲'mse'(mean squared error)。
2.splitter:string, optional (default="best")
            它指定了在每個節點切分的策略。有兩種切分策咯:
            (1).splitter='best':表示選擇最優的切分特徵和切分點。
            (2).splitter='random':表示隨機切分。
3.max_depth:int or None, optional (default=None)
             指定樹的最大深度。如果爲None,則表示樹的深度不限,直到
             每個葉子都是純淨的,即葉節點中所有樣本都屬於同一個類別,
             或者葉子節點中包含小於min_samples_split個樣本。
4.min_samples_split:int, float, optional (default=2)
             整數或者浮點數,默認爲2。它指定了分裂一個內部節點(非葉子節點)
             需要的最小樣本數。如果爲浮點數(0到1之間),最少樣本分割數爲ceil(min_samples_split * n_samples)
5.min_samples_leaf:int, float, optional (default=1)
             整數或者浮點數,默認爲1。它指定了每個葉子節點包含的最少樣本數。
             如果爲浮點數(0到1之間),每個葉子節點包含的最少樣本數爲ceil(min_samples_leaf * n_samples)
6.min_weight_fraction_leaf:float, optional (default=0.)
             它指定了葉子節點中樣本的最小權重係數。默認情況下樣本有相同的權重。
7.max_feature:int, float, string or None, optional (default=None)
             可以是整數,浮點數,字符串或者None。默認爲None。
             (1).如果是整數,則每次節點分裂只考慮max_feature個特徵。
             (2).如果是浮點數(0到1之間),則每次分裂節點的時候只考慮int(max_features * n_features)個特徵。
             (3).如果是字符串'auto',max_features=n_features。
             (4).如果是字符串'sqrt',max_features=sqrt(n_features)。
             (5).如果是字符串'log2',max_features=log2(n_features)。
             (6).如果是None,max_feature=n_feature。
8.random_state:int, RandomState instance or None, optional (default=None)
             (1).如果爲整數,則它指定了隨機數生成器的種子。
             (2).如果爲RandomState實例,則指定了隨機數生成器。
             (3).如果爲None,則使用默認的隨機數生成器。
9.max_leaf_nodes:int or None, optional (default=None)
             (1).如果爲None,則葉子節點數量不限。
             (2).如果不爲None,則max_depth被忽略。
10.min_impurity_decrease:float, optional (default=0.)
             如果節點的分裂導致不純度的減少(分裂後樣本比分裂前更加純淨)大於或等於min_impurity_decrease,則分裂該節點。
             個人理解這個參數應該是針對分類問題時纔有意義。這裏的不純度應該是指基尼指數。
             迴歸生成樹採用的是平方誤差最小化策略。分類生成樹採用的是基尼指數最小化策略。
             加權不純度的減少量計算公式爲:
             min_impurity_decrease=N_t / N * (impurity - N_t_R / N_t * right_impurity
                                - N_t_L / N_t * left_impurity)
             其中N是樣本的總數,N_t是當前節點的樣本數,N_t_L是分裂後左子節點的樣本數,
             N_t_R是分裂後右子節點的樣本數。impurity指當前節點的基尼指數,right_impurity指
             分裂後右子節點的基尼指數。left_impurity指分裂後左子節點的基尼指數。
11.min_impurity_split:float
             樹生長過程中早停止的閾值。如果當前節點的不純度高於閾值,節點將分裂,否則它是葉子節點。
             這個參數已經被棄用。用min_impurity_decrease代替了min_impurity_split。
12.presort: bool, optional (default=False)
             指定是否需要提前排序數據從而加速尋找最優切分的過程。設置爲True時,對於大數據集
             會減慢總體的訓練過程;但是對於一個小數據集或者設定了最大深度的情況下,會加速訓練過程。
屬性:
1.feature_importances_ : array of shape = [n_features]
             特徵重要性。該值越高,該特徵越重要。
             特徵的重要性爲該特徵導致的評價準則的(標準化的)總減少量。它也被稱爲基尼的重要性
2.max_feature_:int
             max_features推斷值。
3.n_features_:int
             執行fit的時候,特徵的數量。
4.n_outputs_ : int
             執行fit的時候,輸出的數量。
5.tree_ : 底層的Tree對象。
Notes:
控制樹大小的參數的默認值(例如``max_depth``,``min_samples_leaf``等)導致完全成長和未剪枝的樹,
這些樹在某些數據集上可能表現很好。爲減少內存消耗,應通過設置這些參數值來控制樹的複雜度和大小。
方法:
1.fit(X,y):訓練模型。
2.predict(X):預測。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章