MNIST data processing

Original source: https://blog.csdn.net/simple_the_best/article/details/75267863

import os
import struct
import numpy as np
import matplotlib.pyplot as plt

1. Download the data

The MNIST dataset is available at http://yann.lecun.com/exdb/mnist/ and consists of four files:

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB after decompression, 60,000 samples)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB after decompression, 60,000 labels)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB after decompression, 10,000 samples)
Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB after decompression, 10,000 labels)
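The four archives are gzip-compressed, so they have to be decompressed before the raw bytes can be read. Below is a minimal sketch using only the standard library. It assumes the .gz files have already been downloaded to a local folder; the helper name `gunzip` is hypothetical and not part of the original post.

```python
import gzip
import shutil
from pathlib import Path


def gunzip(src, dst=None):
    """Decompress a .gz file and return the path of the output file.

    Hypothetical helper for illustration only.
    """
    src = Path(src)
    # Default output name: the input name with the trailing ".gz" removed.
    dst = Path(dst) if dst is not None else src.with_suffix('')
    with gzip.open(src, 'rb') as f_in, open(dst, 'wb') as f_out:
        # Stream the data in chunks instead of loading it all into memory.
        shutil.copyfileobj(f_in, f_out)
    return dst


# Example usage (assumes the archive exists in the current folder):
# gunzip('train-images-idx3-ubyte.gz')
```

The output keeps the archive's own name (for example train-labels-idx1-ubyte), so any file names used when loading the data need to match what is actually on disk.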
 

2. Define a function to read the data

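def load_mnist(path, kind='train'):
    """Load MNIST images and labels from `path`."""
    # Note: the file names must match the downloaded (decompressed) files.
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # Header: magic number and item count, as big-endian unsigned 32-bit ints.
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        # Header: magic number, image count, rows, columns.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        # Flatten each 28x28 image into one row of 784 pixels.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)

    return images, labels  # both are NumPy arrays

# Raw string, so the backslashes in the Windows path are not treated as escapes.
(train_img, train_lab) = load_mnist(r'D:\minist\data', kind='train')

print(train_img.shape)  # (60000, 784) for the full training set
print(train_lab.shape)  # (60000,) for the full training set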

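# Show the first 10 training images in a 2x5 grid.
fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True, )

ax = ax.flatten()
for i in range(10):
    # Restore the flat 784-pixel row to a 28x28 image before plotting.
    img = train_img[i].reshape(28, 28)
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

# The axes are shared, so clearing the ticks on one subplot clears them on all.
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()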
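The `struct.unpack` calls in `load_mnist` read the IDX file header: a big-endian magic number followed by one 32-bit size per dimension. According to the format description on the MNIST page, the magic number is 2049 for label files and 2051 for image files, and its lowest byte gives the number of dimensions. Below is a rough sketch of a header check. The name `read_idx_header` is hypothetical, and the example runs on synthetic bytes rather than the real files.

```python
import io
import struct

LABEL_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension
IMAGE_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions


def read_idx_header(f):
    """Read an IDX header from a binary file object.

    Returns (magic, dims), where dims is a tuple of dimension sizes.
    Hypothetical helper for illustration only.
    """
    magic, = struct.unpack('>I', f.read(4))
    ndim = magic & 0xFF  # the lowest byte stores the number of dimensions
    dims = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
    return magic, dims


# Synthetic image-file header: 2 images of 28x28 pixels.
fake_header = struct.pack('>IIII', IMAGE_MAGIC, 2, 28, 28)
magic, dims = read_idx_header(io.BytesIO(fake_header))
print(magic == IMAGE_MAGIC, dims)
```

Comparing the returned magic number against the expected constant before reshaping is a cheap way to catch a wrong or still-compressed file.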
 
