參考資料:
機器學習實戰
'''
@version: 0.0.1
@Author: Huang
@dev: python3 vscode
@Date: 2019-11-07 23:59:30
@LastEditTime: 2019-11-08 14:16:04
@FilePath: \\機器學習實戰\\09-樹迴歸\\tree.py
@Description: CART是十分著名且廣泛記載的樹構建算法,它使用二元切分來處理連續型變量
'''
import numpy as np
import matplotlib.pyplot as plt
def loadDataSet(fileName):
    """
    [summary]: Load a tab-separated file of numbers.

    Arguments:
        fileName {[str]} -- path of the text file to read

    Returns:
        dataMat -- list of rows; each row is a list of floats, one per
                   tab-separated field

    Raises:
        ValueError -- if a field cannot be converted to float
    """
    dataMat = []
    # The context manager closes the handle even when parsing raises.
    with open(fileName) as fp:
        for line in fp:
            stripped = line.strip()
            if not stripped:
                # A blank line (e.g. a trailing newline) holds no record;
                # float('') would raise, so skip it.
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat
def plotDataSet(fileName):
    """
    [summary]: Show a scatter plot of the first two columns of a data file.

    Arguments:
        fileName {[str]} -- file name, forwarded to loadDataSet
    """
    rows = loadDataSet(fileName)
    # Column 0 is drawn on the x axis, column 1 on the y axis.
    xcord = [row[0] for row in rows]
    ycord = [row[1] for row in rows]

    figure = plt.figure()
    axes = figure.add_subplot(111)
    axes.scatter(xcord, ycord, s=20, c='blue', alpha=0.5)
    plt.title('DataSet')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()
def binSplitDataSet(dataSet, feature, value):
    """
    [summary]: Split a data set in two by thresholding one column.

    Arguments:
        dataSet -- 2-D data set supporting dataSet[:, feature] indexing
        feature -- index of the column to split on
        value -- threshold compared against that column

    Returns:
        mat0 -- rows whose feature column is greater than value
        mat1 -- rows whose feature column is less than or equal to value
    """
    column = dataSet[:, feature]
    # np.nonzero(...)[0] yields the row indices where the condition holds.
    above_rows = np.nonzero(column > value)[0]
    rest_rows = np.nonzero(column <= value)[0]
    return dataSet[above_rows, :], dataSet[rest_rows, :]
def regLeaf(dataSet):
    """
    [summary]: Build a regression-tree leaf value.

    Arguments:
        dataSet -- 2-D data set; the last column is used

    Returns:
        mean of the last column
    """
    targets = dataSet[:, -1]
    return np.mean(targets)
def regErr(dataSet):
    """
    [summary]: Total squared deviation of the last column from its mean.

    Arguments:
        dataSet -- 2-D data set; the last column is used

    Returns:
        variance of the last column multiplied by the number of rows
    """
    row_count = np.shape(dataSet)[0]
    target_variance = np.var(dataSet[:, -1])
    return target_variance * row_count
def linearSolve(dataSet):
    """
    [summary]: Split a data set into design matrix X and target Y, then
    solve the normal equations for the linear-regression weights.

    Arguments:
        dataSet -- 2-D data (matrix, ndarray or nested list); every column
                   except the last is a feature, the last column is the target

    Raises:
        NameError: if the determinant of X.T * X is exactly 0, so the
                   matrix cannot be inverted (exception type kept for
                   existing callers)

    Returns:
        ws -- weights; row 0 multiplies the constant column of X
        X -- design matrix: a column of ones followed by the feature columns
        Y -- target column (last column of dataSet)
    """
    # np.asmatrix replaces np.mat (removed in NumPy 2.0) and also lets
    # ndarray / nested-list input produce a proper column vector for Y.
    data = np.asmatrix(dataSet)
    m, n = np.shape(data)
    # Column 0 stays all ones (intercept term); columns 1..n-1 get the features.
    X = np.asmatrix(np.ones((m, n)))
    X[:, 1:n] = data[:, 0:n - 1]
    Y = data[:, -1]
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n'
                        'try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
def modelLeaf(dataSet):
    """
    [summary]: Build a model-tree leaf: fit a linear model to the data set
    and return only its coefficients.

    Arguments:
        dataSet -- data set forwarded to linearSolve

    Returns:
        the weights returned by linearSolve
    """
    weights, _, _ = linearSolve(dataSet)
    return weights
def modelErr(dataSet):
    """
    [summary]: Squared error of the linear model fitted to the data set.

    Arguments:
        dataSet -- data set forwarded to linearSolve

    Returns:
        sum of squared differences between the targets and the fitted values
    """
    weights, design, targets = linearSolve(dataSet)
    residuals = targets - design * weights
    return sum(np.power(residuals, 2))
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """
    [summary]: Find the best binary split of the data set.
        For every feature:
            For every distinct value of that feature:
                split the data set in two
                compute the error of the split
                keep the split if its error is lower than the best so far
        Return the feature index and threshold of the best split.

    Arguments:
        dataSet {[numpy.matrix]} -- data set; the last column is the target

    Keyword Arguments:
        leafType -- function that builds a leaf value (default: {regLeaf})
        errType -- function that computes the error (default: {regErr})
        ops {tuple} -- (tolS, tolN): minimum error reduction required to
                       split, and minimum number of rows allowed on each
                       side of a split (default: {(1, 4)})

    Returns:
        bestIndex -- index of the best split feature, or None when the node
                     should become a leaf
        bestValue -- split threshold, or the leaf value when bestIndex is None
    """
    tolS, tolN = ops[0], ops[1]

    # Every target value is identical: nothing to gain from splitting.
    distinct_targets = set(dataSet[:, -1].T.tolist()[0])
    if len(distinct_targets) == 1:
        return None, leafType(dataSet)

    _, n_cols = np.shape(dataSet)
    base_err = errType(dataSet)
    best_err, best_feat, best_val = np.inf, 0, 0
    for feat in range(n_cols - 1):
        candidates = set(dataSet[:, feat].T.tolist()[0])
        for threshold in candidates:
            upper, lower = binSplitDataSet(dataSet, feat, threshold)
            # Reject splits that leave either side with fewer than tolN rows.
            if min(np.shape(upper)[0], np.shape(lower)[0]) < tolN:
                continue
            split_err = errType(upper) + errType(lower)
            if split_err < best_err:
                best_feat, best_val, best_err = feat, threshold, split_err

    # The improvement is too small to justify a split.
    if (base_err - best_err) < tolS:
        return None, leafType(dataSet)

    # Re-check the size constraint on the chosen split.
    upper, lower = binSplitDataSet(dataSet, best_feat, best_val)
    if min(np.shape(upper)[0], np.shape(lower)[0]) < tolN:
        return None, leafType(dataSet)
    return best_feat, best_val
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """
    [summary]: Recursively build a tree.
        Find the best feature to split on:
            if the node cannot be split, store it as a leaf
            otherwise perform the binary split
            build the left subtree with createTree()
            build the right subtree with createTree()

    Arguments:
        dataSet -- data set; forwarded to chooseBestSplit

    Keyword Arguments:
        leafType -- function that builds a leaf value (default: {regLeaf})
        errType -- function that computes the error (default: {regErr})
        ops {tuple} -- extra tree-building parameters, forwarded to
                       chooseBestSplit (default: {(1, 4)})

    Returns:
        the leaf value when no split is chosen, otherwise a dict with keys
        'spInd' (split feature), 'spVal' (threshold), 'left' and 'right'
        (subtrees or leaf values)
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity check: feature index 0 is a valid split and must not be
    # confused with "no split".
    if feat is None:
        return val
    # 'left' receives the rows above the threshold, 'right' the rest.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }
def isTree(obj):
    """
    [summary]: Tell whether a node is a subtree rather than a leaf.

    Arguments:
        obj -- node to inspect

    Returns:
        True when obj is a dict (including dict subclasses), False otherwise
    """
    # isinstance accepts dict subclasses and cannot be fooled by an
    # unrelated class that merely happens to be named 'dict'.
    return isinstance(obj, dict)
def getMean(tree):
    """
    [summary]: Collapse a tree into a single value. Each subtree is replaced
    in place by its own collapsed value, then the two children are averaged.

    Arguments:
        tree -- dict node with 'left' and 'right' entries (modified in place)

    Returns:
        the average of the node's left and right values
    """
    # Collapse the right child before the left one.
    for side in ('right', 'left'):
        if isTree(tree[side]):
            tree[side] = getMean(tree[side])
    return (tree['left'] + tree['right']) / 2.0
def prune(tree, testData):
    """
    [summary]: Post-prune a tree using test data.
        Split the test data according to the existing tree:
            if either child is a subtree, prune that child recursively
        compute the error after merging the two leaves
        compute the error without merging
        merge the leaves when merging lowers the error

    Arguments:
        tree -- tree to prune (dict node, modified in place)
        testData -- test set; the last column is the target

    Returns:
        the pruned tree, or a single value when the node was collapsed
    """
    # No test data reaches this node: collapse it to its mean.
    if np.shape(testData)[0] == 0:
        return getMean(tree)

    if isTree(tree['left']) or isTree(tree['right']):
        upper, lower = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], upper)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], lower)

    # A child is still a subtree after pruning: nothing to merge here.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    upper, lower = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    left_sq_err = sum(np.power(upper[:, -1] - tree['left'], 2))
    right_sq_err = sum(np.power(lower[:, -1] - tree['right'], 2))
    errorNoMerge = left_sq_err + right_sq_err
    merged_leaf = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(np.power(testData[:, -1] - merged_leaf, 2))
    if errorMerge < errorNoMerge:
        print("merging")
        return merged_leaf
    return tree
def regTreeEval(model, inDat):
    """
    [summary]: Evaluate a regression-tree leaf.

    Arguments:
        model -- leaf value
        inDat -- input sample; not used by this evaluator

    Returns:
        the leaf value converted to float
    """
    prediction = float(model)
    return prediction
def modelTreeEval(model, inDat):
    """
    [summary]: Evaluate a model-tree leaf: apply the leaf's linear model to
    one input sample.

    Arguments:
        model -- column of weights with n + 1 rows; row 0 multiplies the
                 constant term
        inDat -- 1 x n row of feature values

    Returns:
        the prediction as a float
    """
    n = np.shape(inDat)[1]
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    # Column 0 stays 1.0 so that it multiplies the intercept weight.
    X = np.asmatrix(np.ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    # Extract the single element explicitly; float() on a 1x1 matrix relies
    # on the deprecated ndim>0 array-to-scalar conversion.
    return float((X * model)[0, 0])
def treeForeCast(tree, inData, modelEval=regTreeEval):
    """
    [summary]: Predict the value for one sample by walking the tree.

    Arguments:
        tree -- tree dict, or a leaf value
        inData -- one sample's feature values (1 x n matrix or flat sequence)

    Keyword Arguments:
        modelEval -- function called as modelEval(leaf, inData) to evaluate
                     a leaf (default: {regTreeEval})

    Returns:
        the value produced by modelEval at the leaf the sample reaches
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: on a 1 x n matrix, inData[k] selects row k
    # rather than feature k, which fails for any split index >= 1 and gives
    # an ambiguous truth value when n > 1.
    featureValue = np.asarray(inData).ravel()[tree['spInd']]
    # Values above the threshold go left, matching binSplitDataSet.
    branch = 'left' if featureValue > tree['spVal'] else 'right'
    # The recursive call evaluates the leaf itself when the branch is a leaf.
    return treeForeCast(tree[branch], inData, modelEval)
def createForeCast(tree, testData, modelEval=regTreeEval):
    """
    [summary]: Predict a value for every row of a test set.

    Arguments:
        tree -- tree dict, or a leaf value
        testData -- test samples, one per row

    Keyword Arguments:
        modelEval -- leaf evaluator forwarded to treeForeCast
                     (default: {regTreeEval})

    Returns:
        yHat -- m x 1 matrix of predictions, one per row of testData
    """
    m = len(testData)
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    yHat = np.asmatrix(np.zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, np.asmatrix(testData[i]), modelEval)
    return yHat
if __name__ == '__main__':
    # Demo: build a regression tree from the sample data and print it.
    myDat = loadDataSet(r'./09-樹迴歸/ex00.txt')
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    myMat = np.asmatrix(myDat)
    print(createTree(myMat))