這次我們使用 TensorFlow 對鳶尾花（iris）資料集進行分類。
Python 代碼：
# coding=utf-8
"""Train a small DNN classifier on the iris dataset and print its test accuracy.

Written against the TensorFlow 1.x ``tf.contrib.learn`` API.
"""
from sklearn import metrics,model_selection
import tensorflow as tf
from tensorflow.contrib import learn

# Load the iris dataset and hold out half of it for testing.
# random_state is fixed so the train/test split is reproducible between runs.
iris = learn.datasets.load_dataset('iris')
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    iris.data, iris.target, test_size=.5, random_state=42)

# One real-valued input column of dimension 4 (one value per iris feature).
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Fully connected network with hidden layers of 10, 20 and 10 units and
# 3 output classes.
# NOTE(review): checkpoints go to model_dir. If that directory already holds a
# checkpoint from an earlier run, training presumably resumes from it instead
# of starting fresh. Delete /tmp/iris_model for a clean, comparable run.
clf = learn.DNNClassifier(feature_columns=feature_columns,
                          hidden_units=[10, 20, 10],
                          n_classes=3,
                          model_dir="/tmp/iris_model")
clf.fit(x=X_train, y=y_train, steps=2000)

# Score on the held-out half and print the accuracy.
# print() as a function call works on both Python 2 and Python 3.
print(clf.evaluate(x=X_test, y=y_test)["accuracy"])
準確率（大概值）：
0.96