在我們學習任何一門編程語言的時候,我們做的第一件事情就是寫一個“Hello World!”程序;機器學習的“Hello World!”就是MNIST。MNIST是一個簡單的計算機視覺數據集,它由下述的手寫阿拉伯數字圖像構成:
MNIST同時也包括這些圖像的標註,比如上述圖像的標註分別是5,0,4和1。
在這篇教程中,我們會訓練一個模型去預測圖像是什麼數字。我們當前的目標並不是精心訓練一個業內領先的模型,而是讓你先簡單瞭解一下TensorFlow;但我們後面會爲你提供如何訓練一個業內領先模型的代碼!首先,我們從一個非常簡單的、叫做Softmax迴歸的模型開始。
這個教程的實際代碼非常短,所有有趣的內容3行左右就夠了。然而,對於我們最重要的是理解這些代碼背後的含義:TensorFlow和核心機器學習理念都是如何工作的?所以,下面我們會非常仔細的講解這些代碼。
在Yann Lecun的網站上可找到MNIST數據。爲方便起見,我們放置了自動下載並安裝MNIST數據的Python代碼。你可以象下面一樣下載並import這些代碼,或者簡單的剪刀加漿糊。
- from tensorflow.examples.tutorials.mnistimport input_data
- mnist =input_data.read_data_sets("MNIST_data/", one_hot=True)
from tensorflow.examples.tutorials.mnistimport input_data
mnist =input_data.read_data_sets("MNIST_data/", one_hot=True)
下載的數據可分成三部分:55000份訓練數據(mnist.train)、10000份測試數據(mnist.test)和5000份開發數據(mnist.validation)。這個劃分非常重要:我們已經分離出不用來訓練的數據,這樣可以確保我們的學習具有實際的推廣性能。
之前提到過,每份MNIST數據包含手寫數字圖像(後記爲xs)和對應標註(後記爲ys)2部分。訓練集和測試集均包含xs和ys,所以訓練圖像是mnist.train.images、訓練標註是mnist.train.labels.
每份圖像寬和高均爲28像素點,我們可以理解爲一個大數組。
我們可以把這個數字平鋪爲一個28x28 = 784個點的矢量,只要我們保持一致,怎麼平鋪都無所謂。從這個角度來說,這些MNIST圖像僅僅是784維矢量空間中的一些有豐富結構的點。
平鋪這些數據導致了我們丟失了圖像二維結構信息,這樣做有什麼壞處呢?好吧,後面會介紹的最好的計算機視覺方法會使用這個二維結構信息,但這兒用的Softmax迴歸方法不會使用。
結果就是mnist.train.images是一個形狀爲[55000, 784]的張量(Tensor),第一維是圖像數目、第二維是每幅圖像的像素點。Tensor裏的每個值對應某幅圖像的某個像素點的亮度值,介於0和1之間。
MNIST對應的標註是0到9的數字。我們這份教程中,會把標註設定爲One-hot矢量,One-hot矢量的特點是絕大多數維的值都是0、只有某一維的值是1。這意味着數值是幾,第幾維就是1,如標註3對應的是[0,0,0,1,0,0,0,0,0,0]。這樣,mnist.train.labels是個[55000, 10]的浮點數組。
現在我們準備好實際構建我們的模型啦!