The MNIST Dataset

1. Downloading the MNIST dataset

Baidu Cloud link, extraction code: 0wdr

2. Loading and displaying the data

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
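# Note: tensorflow.examples.tutorials is only available in TensorFlow 1.x
from tensorflow.examples.tutorials.mnist import input_data

# Read the dataset files from /mnist/ with labels encoded as one-hot vectors
mnist = input_data.read_data_sets("/mnist/", one_hot=True)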

# Load data
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

print("x_train: ", x_train.shape)
print("y_train: ", y_train.shape)
print("x_test: ", x_test.shape)
print("y_test: ", y_test.shape)

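def plot_mnist(data, classes):
    """Show 5 sample images for each of the 10 digit classes."""
    for i in range(10):
        # boolean mask selecting the images of class i
        idxs = (classes == i)

        # take the first 10 images of class i (only the first 5 are shown)
        images = data[idxs][0:10]

        for j in range(5):
            # 5 rows x 10 columns: column i holds class i, row j its j-th sample
            plt.subplot(5, 10, i + j*10 + 1)
            plt.imshow(images[j].reshape(28, 28), cmap='gray')
            # print a title only once for each class
            if j == 0:
                plt.title(i)
            plt.axis('off')
    plt.show()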

classes = np.argmax(y_train, 1)
plot_mnist(x_train, classes)

Output:

Extracting /mnist/train-images-idx3-ubyte.gz
Extracting /mnist/train-labels-idx1-ubyte.gz
Extracting /mnist/t10k-images-idx3-ubyte.gz
Extracting /mnist/t10k-labels-idx1-ubyte.gz
x_train: (55000, 784)
y_train: (55000, 10)
x_test: (10000, 784)
y_test: (10000, 10)
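
The shapes printed above follow from how the loader stores the data: each image is a flattened 28 x 28 array (784 values), and with one_hot = True each label is a length-10 indicator vector. Below is a minimal sketch, on synthetic arrays rather than real MNIST samples (the values are made up for illustration), of how a row maps back to an image and a class index:

```python
import numpy as np

# Synthetic stand-ins for three MNIST rows (hypothetical values, not real data).
rng = np.random.default_rng(0)
fake_images = rng.random((3, 784)).astype(np.float32)  # same layout as x_train rows
fake_labels = np.zeros((3, 10))
fake_labels[[0, 1, 2], [7, 2, 7]] = 1.0                # one-hot rows for classes 7, 2, 7

# One-hot -> integer class index, as done before calling plot_mnist.
classes = np.argmax(fake_labels, axis=1)

# Flattened row -> 28 x 28 image, as done inside plt.imshow(...).
first_image = fake_images[0].reshape(28, 28)

# Boolean mask selecting every row of a given class.
sevens = fake_images[classes == 7]

print(classes)            # [7 2 7]
print(first_image.shape)  # (28, 28)
print(sevens.shape)       # (2, 784)
```

The same boolean-mask indexing is what lets plot_mnist pick out all images of one digit with data[idxs].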
[Figure: 5 x 10 grid of sample digits drawn by plot_mnist, one column per class]
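
The tensorflow.examples.tutorials module used above ships only with TensorFlow 1.x. On TensorFlow 2.x, tf.keras.datasets.mnist.load_data() returns uint8 images of shape (N, 28, 28) and integer labels instead. The helper below is a sketch of converting that layout into the flattened, one-hot format printed above; to_flat_one_hot is a hypothetical name, and the demo runs on synthetic arrays rather than a real download. Note that the Keras training split has 60000 images, because it does not set aside the 5000-image validation set that read_data_sets takes out of the training data by default.

```python
import numpy as np

def to_flat_one_hot(images, labels, num_classes=10):
    """Convert (N, 28, 28) uint8 images and integer labels to the
    (N, 784) float / (N, 10) one-hot layout used in this post."""
    n = images.shape[0]
    # Flatten each image and scale pixel values from [0, 255] to [0, 1].
    flat = images.reshape(n, -1).astype(np.float32) / 255.0
    # Row k of the identity matrix is the one-hot vector for class k.
    one_hot = np.eye(num_classes)[labels.astype(np.int64)]
    return flat, one_hot

# Synthetic demo data (assumption: same dtypes and shapes Keras returns).
demo_images = np.full((4, 28, 28), 255, dtype=np.uint8)
demo_labels = np.array([0, 3, 9, 3], dtype=np.uint8)
x_demo, y_demo = to_flat_one_hot(demo_images, demo_labels)

print(x_demo.shape, y_demo.shape)  # (4, 784) (4, 10)
print(np.argmax(y_demo, axis=1))   # [0 3 9 3]

# With TensorFlow 2.x installed, the real data would be loaded as:
#   import tensorflow as tf
#   (x_tr, y_tr), (x_te, y_te) = tf.keras.datasets.mnist.load_data()
#   x_train, y_train = to_flat_one_hot(x_tr, y_tr)
```

The converted arrays can then be passed to plot_mnist in the same way, after taking np.argmax of the one-hot labels.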
