Tensorflow 讀取 MNIST

STEP1. 下載MNIST訓練數據集

方法一:使用tensorflow下載

from tensorflow.examples.tutorials.mnist import input_data
# 下載mnist數據集
mnist = input_data.read_data_sets('/mydata/', one_hot=True)

當前文件位置就會出現mydata文件夾,裏面有4個壓縮文件

方法二:手動下載        下載地址:http://yann.lecun.com/exdb/mnist/ 

將得到的壓縮文件解壓,

         

STEP2. 讀取圖片

用法一:

 MNIST 網站上對數據集的介紹:

以Training set image爲例,讀取時需要跳過 4個integer 類型        而 Training set label 需要跳過2個integer 類型

讀取 Training set 的 image 和 label 文件 

def readfile():
    with open('data/MNIST_data/train-images.idx3-ubyte','rb') as f:
        train_image = f.read()
    with open('data/MNIST_data/train-labels.idx1-ubyte', 'rb') as f:
        train_labels = f.read()
    return train_image,train_labels

讀取第一張圖並顯示

image,label=readfile()
index = struct.calcsize('>IIII')    # I代表一個無符號整數 ,跳過四個
temp = struct.unpack_from('>784B', image, index) #MNIST中的圖片都是28*28的,所以讀取784bit
img=np.reshape(temp, (28, 28))     
plt.imshow(img,cmap='gray')
plt.show()

寫了個讀取前 n 張圖片和標籤的

import tensorflow as tf
import struct
import matplotlib.pyplot as plt
import numpy as np

def readfile():
    with open('data/MNIST_data/train-images.idx3-ubyte','rb') as f:
        train_image = f.read()
    with open('data/MNIST_data/train-labels.idx1-ubyte', 'rb') as f:
        train_labels = f.read()
    return train_image,train_labels
'''
讀取前n張圖片
'''
def get_images(buf,n):
    im=[]
    index = struct.calcsize('>IIII')
    for i in range(n):
        temp = struct.unpack_from('>784B', buf, index)
        im.append(np.reshape(temp, (28, 28)))
        index += struct.calcsize('>784B')
    return im
'''
讀取前n個標籤
'''
def get_labels(buf,n):
    l=[]
    index = struct.calcsize('>II')
    for i in range(n):
        temp = struct.unpack_from('>1B', buf, index)
        l.append(temp[0])
        index += struct.calcsize('>1B')
    return l

'''
讀取
'''
image,label=readfile()
n=16
train_img=get_images(image,n)
train_label=get_labels(label,n)

'''
顯示
'''
for i in range(16):
    plt.subplot(4,4,1+i)
    title = u"label:" + str(train_label[i])
    plt.title(title)
    plt.imshow(train_img[i],cmap='gray')
plt.show()

用法二:

原始的 MNIST 數據集有6000張訓練圖片,1000張驗證圖片,而 Tensorflow 又將訓練圖片劃分成 5500 張訓練圖片和 500 張驗證圖片

from tensorflow.examples.tutorials.mnist import input_data
# 下載mnist數據集
mnist = input_data.read_data_sets('/mydata/', one_hot=True)
'''
查看一下各個變量的大小
'''
print(mnist.train.images.shape)       #(55000, 784)   訓練圖像 55000個樣本
print(mnist.train.labels.shape)       #(55000, 10)    訓練標籤
print(mnist.validation.images.shape)  #(5000, 784)    驗證圖像 55000個樣本
print(mnist.validation.labels.shape)  #(5000, 10)     驗證標籤
print(mnist.test.images.shape)        #(10000, 784)   測試圖像 55000個樣本
print(mnist.test.labels.shape)        #(10000, 10)    測試標籤

讀取一張 mnist.train.images 裏的圖片

temp = mnist.train.images[0]
img=np.reshape(temp, (28, 28))
plt.imshow(img,cmap='gray')
plt.show()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章