sklearn 決策樹實例——泰坦尼克號生存率預測

題目地址：Kaggle 競賽 Titanic: Machine Learning from Disaster
相關的訓練集和測試集我在 Excel 中做了少量調整，下載鏈接（提取碼：mdwi）
分數不高，但用到的都是最基本的知識。

from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import warnings
from sklearn import tree  # 導入決策樹
import graphviz
"""
打印處理
"""
# 不顯示warning
warnings.filterwarnings('ignore')
# 顯示所有列
pd.set_option('display.max_columns', None)
# 顯示所有行
pd.set_option('display.max_rows', None)
# 顯示寬度爲1000
pd.set_option('display.width', 1000)


def prep(csvfile):
    """
    Load a Titanic CSV file and encode/clean it for the decision tree.

    * ``Embarked``: C -> 1 (Cherbourg), Q -> 2 (Queenstown),
      S -> 3 (Southampton); missing values become 3.
    * ``Sex``: male -> 1, female -> 0.
    * ``Age`` / ``Fare``: missing values are replaced by the column mean.

    All other columns are returned unchanged.

    :param csvfile: path or file-like object accepted by ``pd.read_csv``.
    :return: the cleaned ``DataFrame``.
    """
    df = pd.read_csv(csvfile)
    # Each encoded column is built as a new Series and assigned back whole.
    # Chained assignment (df['col'][mask] = value) is unreliable: it raises
    # SettingWithCopyWarning (hidden by the global warnings filter) and is a
    # silent no-op under pandas Copy-on-Write.
    # Embarked 1 = Cherbourg, 2 = Queenstown, 3 = Southampton. Codes outside
    # C/Q/S are treated as missing and, like NaN, default to Southampton (3).
    df['Embarked'] = df['Embarked'].map({'C': 1, 'Q': 2, 'S': 3}).fillna(3)
    # Sex 1 = male, 0 = female. NOTE: any other value would become NaN.
    df['Sex'] = df['Sex'].map({'male': 1, 'female': 0})
    # Fill missing numeric values with the column mean (computed over the
    # non-missing rows before filling).
    for column in ('Age', 'Fare'):
        df[column] = df[column].fillna(df[column].mean())
    # Preview:
    # print(df.head(5))
    return df


"""
獲取訓練集、測試集
"""
# 訓練集
train = prep('data.csv')
Y_train = train['Survived'].values
X_train = train.drop('Survived', axis=1).values
# 測試集
X_test = prep('test.csv')

"""
實例化&訓練決策樹
"""
dtc = DecisionTreeClassifier(splitter='random', max_depth=10).fit(X_train, Y_train)
# 打分
print('訓練集的分數:{}分'.format(round(100 * dtc.score(X_train, Y_train), 1)))
# 查看重要度
print(*zip([column for column in (train.drop('Survived', axis=1))],
           dtc.feature_importances_
           )
      )
# 測試
result = pd.DataFrame(dtc.predict(X_test), columns=['Survived'])
result.index = result.index + 892
# 格式化輸出測試結果
result.to_csv('Titanic_result.csv',
              index_label='PassengerId'
              )

# Create Tree Picture
dot_data = tree.export_graphviz(dtc,
                                filled=True,
                                rounded=True,
                                out_file=None,
                                )
graph = graphviz.Source(dot_data)
graph.render()

結果：
（此處原為運行結果截圖，圖片已缺失）

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章