Theano入門——MNIST數據庫
1. MNIST數據庫的文件格式
train-images-idx3-ubyte: 訓練圖像集
train-labels-idx1-ubyte: 訓練標籤集
t10k-images-idx3-ubyte: 測試圖像集
t10k-labels-idx1-ubyte: 測試標籤集
訓練集包含60000個例子,測試集包含10000個例子。[偏移] [類型] [數值] [描述]
0000 32位整數 0x00000801(2049) 幻數(MSB first)
0004 32位整數 60000 樣本總數
0008 unsigned byte ?? 標籤
0009 unsigned byte ?? 標籤
........
xxxx unsigned byte ?? 標籤
[偏移] [類型] [數值] [描述]
0000 32位整數 0x00000803(2051) 幻數
0004 32位整數 60000 圖像總數
0008 32位整數 28 圖像行數
0012 32位整數 28 圖像列數
0016 unsigned byte ?? 像素
0017 unsigned byte ?? 像素
........
xxxx unsigned byte ?? 像素
像素逐行組織。像素值爲0到255。0爲背景(白色),255爲前景(黑色)。[偏移] [類型] [數值] [描述]
0000 32位整數 0x00000801(2049) 幻數(MSB first)
0004 32位整數 10000 樣本總數
0008 unsigned byte ?? 標籤
0009 unsigned byte ?? 標籤
........
xxxx unsigned byte ?? 標籤
標籤值爲0到9。
[偏移] [類型] [數值] [描述]
0000 32位整數 0x00000803(2051) 幻數
0004 32位整數 10000 圖像總數
0008 32位整數 28 圖像行數
0012 32位整數 28 圖像列數
0016 unsigned byte ?? 像素
0017 unsigned byte ?? 像素
........
xxxx unsigned byte ?? 像素
像素逐行組織。像素值爲0到255。0爲背景(白色),255爲前景(黑色)。2.IDX文件格式
幻數
0維大小
1維大小
2維大小
.....
N維大小
數據
幻數爲1個整數(MSB first),它的前兩個字節總是0。
第3個字節碼爲數據類型:
0x08: unsigned byte
0x09: signed byte
0x0B: short (2 bytes)
0x0C: int (4 bytes)
0x0D: float (4 bytes)
0x0E: double (8 bytes)
第4個字節碼爲向量或矩陣的維數:1爲向量,2爲矩陣,...
每維的大小爲4字節整數(MSB first,high endian)。
3.加載數據庫
(1)MNIST數據庫文件夾mnist存放在相對文件路徑datasets_dir下。
(2)mnist函數
訓練圖像文件的16個字節後爲像素數據(16字節前爲幻數和數據各維的大小),將逐行排序的像素變形爲行數爲60000,列數爲28*28的矩陣trX。最後像素的數據類型定爲浮點型;訓練標籤文件的8個字節後爲標籤數據,將逐行排序的標籤變形爲行數爲60000,列數爲1的矩陣trY。最後標籤的數據類型定爲無符號整型。測試圖像文件和測試標籤文件同理,只是分別產生的矩陣teX和teY的行數由60000變爲10000。
像素值爲0到255,所以將像素值歸一化到[0,1]區間內。
選擇前ntrain個訓練圖像例子和標籤和前ntest個測試圖像例子和標籤。
如果onehot爲真,則執行onehot函數。
(3)one_hot函數
該函數的輸入有x和n,x爲標籤,n爲標籤編碼的範圍跨度。MNIST的數據標籤的值爲0~9,範圍跨度爲10。首先設置長度爲標籤長度(即數據例子的個數),列數爲標籤的範圍跨度的矩陣數組o_h,然後用np.arange(len(x))確定置1的行,用x確定置1的列。範圍跨度爲10的one-hot編碼形式舉例:
(4)圖像樣本顯示
選擇第6個圖像樣本和第256個圖像樣本作顯示(Xtrain數組索引從0開始)。
4. 示例代碼
import numpy as np
import os
import matplotlib.pyplot as plt
datasets_dir = 'media/datasets/'
def one_hot(x,n):
if type(x) == list:
x = np.array(x)
x = x.flatten()
o_h = np.zeros((len(x),n))
o_h[np.arange(len(x)),x] = 1
return o_h
def mnist(ntrain=60000,ntest=10000,onehot=True):
data_dir = os.path.join(datasets_dir,'mnist/')
fd = open(os.path.join(data_dir,'train-images.idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trX = loaded[16:].reshape((60000,28*28)).astype(float)
fd = open(os.path.join(data_dir,'train-labels.idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trY = loaded[8:].reshape((60000))
fd = open(os.path.join(data_dir,'t10k-images.idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teX = loaded[16:].reshape((10000,28*28)).astype(float)
fd = open(os.path.join(data_dir,'t10k-labels.idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teY = loaded[8:].reshape((10000))
trX = trX/255.
teX = teX/255.
trX = trX[:ntrain]
trY = trY[:ntrain]
teX = teX[:ntest]
teY = teY[:ntest]
if onehot:
trY = one_hot(trY, 10)
teY = one_hot(teY, 10)
else:
trY = np.asarray(trY)
teY = np.asarray(teY)
return trX,teX,trY,teY
Xtrain, Xtest, Ytrain, Ytest = mnist()
################################################
# 數字樣本顯示
image = Xtrain[5].reshape(28,28)
image1 = Xtrain[255].reshape(28,28)
fig = plt.figure()
ax = fig.add_subplot(121)
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
plt.imshow(image, cmap='gray')
ax = fig.add_subplot(122)
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
plt.imshow(image1, cmap='gray')
plt.show()
5. 實驗結果
6.參考鏈接
(1)數據庫下載:http://yann.lecun.com/exdb/mnist/
(2)數據庫加載:https://github.com/Newmu/Theano-Tutorials/blob/master/load.py
(3)圖像顯示:http://www.cnblogs.com/x1957/archive/2012/06/02/2531503.html