基於tensorflow開發框架,搭建softmax模型完成mnist分類任務。
本文的完整代碼託管在我的Github PnYuan - Practice-of-Machine-Learning - MNIST_tensorflow_demo,歡迎訪問。
1.任務背景
1.1.目的
以MNIST手寫數字識別爲課題,研究基本深度學習方法的應用。本文先從Softmax模型切入,以熟悉tensorflow下mnist任務的開發流程。之後的文章將陸續引入MLP、CNN等模型,以達到更優異的識別效果。
1.2.數據集
MNIST任務是圖像識別領域經典的“Helloworld”。在其所提供的數據集中,包含了6w個訓練樣本和1w個測試樣本,均爲黑白圖片,大小28×28,以灰度矩陣的形式存放,數值取浮點數“0~1”對應“白~黑”。給出一些圖片(X)及對應標註(Y)如下圖所示:
2.實驗過程
2.1.數據預研
MNIST數據的一些基本信息如下:
輸入:image - 784 的向量 --> image size [28*28]
輸出:label - int(0-10)
#train:55k
#valid:5k
#test:10k
基於tensorflow對mnist數據進行加載與測試的示例代碼如下:
mnist = input_data.read_data_sets('../data/mnist_data/',one_hot=True)
X_train_org, Y_train_org = mnist.train.images, mnist.train.labels
X_valid_org, Y_valid_org = mnist.validation.images, mnist.validation.labels
X_test_org, Y_test_org = mnist.test.images, mnist.test.labels
# check the shape of dataset
print("train set shape: X-", X_train_org.shape, ", Y-", Y_train_org.shape)
print("valid set shape: X-", X_valid_org.shape, ", Y-", Y_valid_org.shape)
print("test set shape: X-", X_test_org.shape, ", Y-", Y_test_org.shape)
2.2.Softmax建模
Softmax迴歸可看作是Logistic迴歸模型向多分類任務的拓展,其模型可描述如下圖:
其公式表達如下:
寫成向量化形式:
權值 W 和偏置 b 是這裏需要學習的參數。
採用tensorflow可以輕鬆構建出Softmax模型,示例代碼如下:
#========== Softmax Modeling ==========#
x = tf.placeholder("float", [None, 784]) # placeholder of input
W = tf.Variable(tf.zeros([784,10])) # parameters (initial to 0)
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b) # softmax computation graph
2.3.訓練與測試
通過構建tensorflow對話(session),給定輸入x
,運行由x-->y
的計算圖(Computation Graph),後臺可簡單完成訓練過程。這裏採用的是簡單的mini-batch Gradient Descent
優化策略。
模型的訓練樣例代碼如下:
y_ = tf.placeholder("float", [None, 10]) # placeholder of label
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) # loss (cross-entropy)
train_step = tf.train.GradientDescentOptimizer(learning_rate = 0.01).minimize(cross_entropy) # using GD
#========== Training ==========#
init = tf.global_variables_initializer()
sess = tf.InteractiveSession() # initial a session
sess.run(init)
for i in range(1000): # iterate for 100 times
batch_xs, batch_ys = mnist.train.next_batch(100) # using mini-batch
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
在驗證集與測試集上評估所學模型的效果,以預測準確率(accuracy)爲指標,得出結果如下:
valid accuracy 0.927
test accuracy 0.9208
可以看出,Softmax模型在經過一定時間的訓練之後,達到了九成的分類準確率。與MNIST官網給出的線性分類器(單層NN)的準確級別相近。
3.實驗小結
這裏採用tensorflow開發框架搭建了Softmax多分類模型,實現了超過90%的測試準確率。模型的搭建以及訓練測試過程十分簡便。據tensorflow官網所述,使用多層神經網絡等更復雜的模型還可進一步提升分類效果,接下來的文章,將對此進行跟進。