隨機森林
隨機森林的基本知識,許多博客都已經講得很清晰,但是用 TensorFlow 實現的代碼卻很少。
博主屬於機器學習未入門級別,純屬分享一段可以運行的隨機森林代碼。若有錯誤,麻煩指出。
本代碼以 MNIST 數據集爲示例,通過 TensorFlow 實現隨機森林。
import tensorflow as tf
import pandas as pd
import numpy as np
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources
####################
# Load the data and preprocess it.
# The resulting X arrays have shape [?, 784].
# The resulting y arrays have shape [?]; y must hold integer class labels,
# not one-hot vectors.
#
# X must be normalized.
####################
dataset = pd.read_csv('./data/train.csv').values

# NOTE(review): assumes the Kaggle "Digit Recognizer" CSV layout -- column 0 is
# the digit label and the remaining columns are pixel intensities in [0, 255].
# Confirm against the actual file before relying on this.
y_all = dataset[:, 0].astype(np.int32)
X_all = dataset[:, 1:].astype(np.float32) / 255.0  # scale pixels into [0, 1]

# The training loop further down reads X_train/y_train, X_val/y_val and
# X_test/y_test, so they are defined here. The 80/10/10 split is an arbitrary
# choice (TODO confirm the desired ratios). A private, fixed-seed RandomState
# keeps the split reproducible without touching the global np.random state.
_split_order = np.random.RandomState(0).permutation(dataset.shape[0])
_n_train = int(0.8 * _split_order.shape[0])
_n_val = int(0.1 * _split_order.shape[0])
_train_idx, _val_idx, _test_idx = np.split(
    _split_order, [_n_train, _n_train + _n_val]
)

X_train, y_train = X_all[_train_idx], y_all[_train_idx]
X_val, y_val = X_all[_val_idx], y_all[_val_idx]
X_test, y_test = X_all[_test_idx], y_all[_test_idx]
# Hyperparameters

# Number of training iterations.
num_steps = 5000

# Number of samples in each training batch.
batch_size = 1000

# Number of target classes.
num_classes = 10

# Number of input features per sample.
num_features = 784

# Number of trees in the forest.
num_trees = 50

# Maximum number of nodes allowed in each tree.
max_nodes = 800
# Input placeholders: X receives rows of num_features float features and
# y receives one integer class label per row. The feature dimension is taken
# from num_features so it always matches the value given to ForestHParams.
X = tf.placeholder(tf.float32, shape=[None, num_features])
y = tf.placeholder(tf.int32, shape=[None])

# Internal parameters of the random forest.
hparams = tensor_forest.ForestHParams(
    num_classes=num_classes,
    num_features=num_features,
    num_trees=num_trees,
    max_nodes=max_nodes,
).fill()

# Build the random forest graph.
forest_graph = tensor_forest.RandomForestGraphs(hparams)

# Training op and loss op, both fed from the same X / y placeholders.
train_op = forest_graph.training_graph(X, y)
loss_op = forest_graph.training_loss(X, y)

# Accuracy: take the arg-max over axis 1 of the inference output as the
# predicted class and compare it with y (cast to int64 to match tf.argmax),
# then average the matches.
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Group the global-variable initializer with the initializer for shared
# resources so that a single run() call initializes both.
init_var = tf.group(
    tf.global_variables_initializer(),
    resources.initialize_resources(resources.shared_resources()),
)

# Create the session and run the initialization. The session is left open
# because the code that follows keeps using it.
sess = tf.Session()
sess.run(init_var)
# Draw a batch of `size` randomly chosen training samples.
def next_batch(X_data, y_data, size):
    """Return a random batch of matching rows from ``X_data`` and ``y_data``.

    A fresh random ordering of the sample indices is drawn on every call and
    its first ``size`` entries are used, so a batch never contains the same
    sample twice and each X row stays paired with its own label.

    Args:
        X_data: Array of samples, indexed along axis 0 with an integer index
            array (i.e. a NumPy array).
        y_data: Array of labels with one entry per sample; its ``shape[0]``
            determines how many indices are shuffled.
        size: Requested number of samples. If it exceeds the number of
            available samples, all samples are returned (in random order).

    Returns:
        A ``(X_batch, y_batch)`` tuple selected with the same indices.
    """
    order = np.random.permutation(y_data.shape[0])
    picked = order[:size]
    return X_data[picked], y_data[picked]
# Training loop: one random batch per step, for num_steps steps.
for i in range(1, num_steps + 1):
    batch = next_batch(X_train, y_train, batch_size)
    # Run one training step and fetch the loss for the same batch.
    _, cost = sess.run([train_op, loss_op], feed_dict={X: batch[0], y: batch[1]})
    # Report progress on the first step and then every 100 steps.
    if i % 100 == 0 or i == 1:
        train_acc = sess.run(accuracy_op, feed_dict={X: batch[0], y: batch[1]})
        val_acc = sess.run(accuracy_op, feed_dict={X: X_val, y: y_val})
        # The last value is measured on the validation split, so it is
        # labelled as such; the test split is only evaluated after training.
        print('Step %i, Loss: %f, Train_Acc: %s, Val_Acc: %s' % (i, cost, train_acc, val_acc))

# Final evaluation on the test split.
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: X_test, y: y_test}))