Decision Trees in Python

1. Setup (Linux):
(1) sudo apt-get install graphviz
(2) sudo pip install graphviz
(3) sudo pip install pydotplus
(4) sudo pip install numpy pandas matplotlib scikit-learn

2. Split criterion (criterion):
You can use either Gini impurity or entropy (information gain):
criterion = 'gini'
criterion = 'entropy'
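Both criteria measure how mixed the classes in a node are, and a split is chosen so that the two children are purer than the parent. The following is a minimal sketch in plain Python of both measures and of the resulting information gain. The helper names (class_probabilities, gini, entropy, information_gain) are hypothetical illustrations, not scikit-learn functions; scikit-learn computes the equivalent quantities internally.

```python
from collections import Counter
from math import log2


def class_probabilities(labels):
    """Fraction of the samples that belong to each class."""
    n = len(labels)
    return [count / n for count in Counter(labels).values()]


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 (0 for a pure node)."""
    if not labels:
        return 0.0
    return 1.0 - sum(p ** 2 for p in class_probabilities(labels))


def entropy(labels):
    """Shannon entropy in bits: sum_k p_k * log2(1 / p_k) (0 for a pure node)."""
    if not labels:
        return 0.0
    return sum(p * log2(1 / p) for p in class_probabilities(labels))


def information_gain(parent, left, right, impurity=entropy):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    children = len(left) / n * impurity(left) + len(right) / n * impurity(right)
    return impurity(parent) - children


node = ['PG', 'PG', 'C', 'C']
print(gini(node), entropy(node))                                # 0.5 1.0
print(information_gain(node, ['PG', 'PG'], ['C', 'C']))         # 1.0
print(information_gain(node, ['PG', 'PG'], ['C', 'C'], gini))   # 0.5
```

A 50/50 node is maximally mixed for two classes (Gini 0.5, entropy 1 bit), and a split that separates the classes perfectly recovers all of it.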

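3. Code: classify NBA players from the Kaggle dataset by position, point guard (PG) vs. center (C), using two features, 2-point field goals (2P) and assists (AST):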

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import pydotplus

from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder

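# Colour maps: one for the decision-region background, one for the scatter points
cBackground = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cPos = mpl.colors.ListedColormap(['g','r','b'])
positions = 'SF'              # small forward, plotted separately at the end
positionsNSF = 'PG', 'C'      # the two classes to separate: point guard, center
sF = 'AST', '2P'              # features: assists, 2-point field goals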

# CSV files from the Kaggle dataset
players = pd.read_csv('Players.csv')
seasonStat = pd.read_csv('Seasons_Stats.csv')
totalPos, totalAST, total2P = seasonStat['Pos'], seasonStat[sF[0]], seasonStat[sF[1]]


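# Keep only PG and C rows; skip rows where either feature is missing
pos, analysis = [], []

for ind, text in enumerate(totalPos):
    if text in positionsNSF and pd.notnull(totalAST[ind]) and pd.notnull(total2P[ind]):
        pos.append(text)
        analysis.append([totalAST[ind], total2P[ind]])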


# Encode the labels ('C' -> 0, 'PG' -> 1) and hold out 30% of the rows for testing
labels = LabelEncoder().fit_transform(pos)
xTrain, xTest, labelTrain, labelTest = train_test_split(analysis, labels, train_size = 0.7)

model = DecisionTreeClassifier(criterion = 'gini')
model.fit(xTrain, labelTrain)
testHat = model.predict(xTest)
print('accuracy: ', accuracy_score(labelTest, testHat))

# Predict the class at every point of a chunks x chunks grid to draw the decision regions
chunks = 50
analysis = np.array(analysis).T
astMin, pstMin = min(analysis[0]), min(analysis[1])
astMax, pstMax = max(analysis[0]), max(analysis[1])

astAxis, pstAxis = np.linspace(astMin, astMax, chunks), np.linspace(pstMin, pstMax, chunks)
xGrid, yGrid = np.meshgrid(astAxis, pstAxis)
xyStack = np.stack((xGrid.flat,yGrid.flat), axis = 1)
yHat = model.predict(xyStack).reshape(xGrid.shape)
xTest = np.array(xTest).T

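# Decision regions as the background; test points coloured by their true position,
# so misclassified points stand out against the background
plt.pcolormesh(xGrid, yGrid, yHat, cmap = cBackground)
plt.scatter(xTest[0], xTest[1], c = labelTest, s = 40, cmap = cPos)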
plt.xlim(astMin, astMax)
plt.ylim(pstMin, pstMax)
plt.xlabel(sF[0])
plt.ylabel(sF[1])
plt.title('PG and C')
plt.grid()
plt.show()

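# Collect the SF rows the same way (no classifier here, only a scatter plot)
analysis = []
for ind, text in enumerate(totalPos):
    if text == positions and pd.notnull(totalAST[ind]) and pd.notnull(total2P[ind]):
        analysis.append([totalAST[ind], total2P[ind]])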

analysis = np.array(analysis).T
astMin, pstMin = min(analysis[0]), min(analysis[1])
astMax, pstMax = max(analysis[0]), max(analysis[1])
plt.scatter(analysis[0], analysis[1], c = 'r')
plt.grid()
plt.xlim(astMin, astMax)
plt.ylim(pstMin, pstMax)
plt.xlabel(sF[0])
plt.ylabel(sF[1])
plt.title('SF')
plt.show()
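To see roughly what DecisionTreeClassifier does at each node, here is a minimal sketch in plain Python of a single split search (a "decision stump") using Gini impurity: try every feature and every midpoint between consecutive distinct values, and keep the split with the lowest size-weighted impurity of the two children. The function name best_split and the (AST, 2P) numbers are hypothetical and made up for illustration; they are not taken from the Kaggle data, and this is not scikit-learn's actual implementation.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum_k p_k^2."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(rows, labels):
    """Exhaustive search for the best single split (a decision stump).

    Returns (feature_index, threshold, weighted_gini) for the split whose two
    children have the lowest size-weighted Gini impurity, or None if no
    feature has two distinct values.
    """
    n = len(rows)
    best = None
    for f in range(len(rows[0])):
        values = sorted(set(r[f] for r in rows))
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2  # candidate threshold: rows with value <= t go left
            left = [y for r, y in zip(rows, labels) if r[f] <= t]
            right = [y for r, y in zip(rows, labels) if r[f] > t]
            score = len(left) / n * gini(left) + len(right) / n * gini(right)
            if best is None or score < best[2]:  # keep the first of equally good splits
                best = (f, t, score)
    return best


# Made-up (AST, 2P) totals for illustration only, NOT from the Kaggle CSV
rows = [[500, 150], [450, 200], [600, 180], [100, 400], [80, 350], [120, 420]]
labels = ['PG', 'PG', 'PG', 'C', 'C', 'C']

print(best_split(rows, labels))  # (0, 285.0, 0.0): split on AST at 285
```

A full tree applies the same search recursively to each child until the nodes are pure or a stopping rule such as max_depth is reached. scikit-learn's optimized implementation differs in details; for example, it permutes the features randomly at each split, so ties between equally good splits can be broken differently.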

4. Result on the test data (accuracy 88%; the exact value varies with the random train/test split):
[Figure: "PG and C", decision regions over AST (x) and 2P (y) with the test points]
5. The small forward (SF) data is also plotted:
[Figure: "SF", scatter of AST (x) against 2P (y)]
The SF points are widely scattered, which suggests that small forwards are relatively versatile, all-round players.
