MNIST 數據集
1. MNIST 數據集下載
百度雲鏈接，提取碼：0wdr
2. 載入數據並可視化
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Read the four MNIST archives from /mnist/. one_hot=True turns each label
# into a length-10 indicator vector (see the printed y_* shapes below).
# NOTE(review): the tensorflow.examples.tutorials helper is a TF 1.x API and
# is not shipped with TensorFlow 2.x -- confirm the TF version this targets.
mnist = input_data.read_data_sets("/mnist/", one_hot=True)

# Load data: flattened images plus their one-hot labels, per split.
x_train, y_train = mnist.train.images, mnist.train.labels
x_test, y_test = mnist.test.images, mnist.test.labels

# Report the array shape of every split.
for _label, _array in (("x_train", x_train), ("y_train", y_train),
                       ("x_test", x_test), ("y_test", y_test)):
    print(_label + ": ", _array.shape)
del _label, _array  # keep the loop variables out of the module namespace
def plot_mnist(data, classes, n_rows=5, n_classes=10, image_shape=(28, 28)):
    """Show a grid of sample images with one column per class.

    Column ``i`` holds up to ``n_rows`` images whose entry in ``classes``
    equals ``i``. The first image of each column is titled with ``i``.

    Args:
        data: 2-D array of flattened images, one row per sample.
        classes: integer class id for each row of ``data`` (array or list).
        n_rows: maximum number of images drawn per class (grid rows).
        n_classes: number of classes ``0 .. n_classes - 1`` (grid columns).
        image_shape: shape each flattened row is reshaped to for display.
    """
    data = np.asarray(data)
    # np.asarray so that a plain list of labels still yields an element-wise
    # boolean mask below (list == int would just be False).
    classes = np.asarray(classes)
    for i in range(n_classes):
        # First n_rows samples of class i; fewer if the class is rare, which
        # previously raised IndexError.
        images = data[classes == i][:n_rows]
        for j, image in enumerate(images):
            # Row-major subplot index: row j, column i.
            plt.subplot(n_rows, n_classes, i + j * n_classes + 1)
            plt.imshow(image.reshape(image_shape), cmap='gray')
            # Title only the top image of each column.
            if j == 0:
                plt.title(i)
            plt.axis('off')
    plt.show()
# Convert the one-hot label rows back to integer class ids, then show
# sample training images grouped by class.
classes = np.argmax(y_train, axis=1)
plot_mnist(data=x_train, classes=classes)
運行結果：
Extracting /mnist/train-images-idx3-ubyte.gz
Extracting /mnist/train-labels-idx1-ubyte.gz
Extracting /mnist/t10k-images-idx3-ubyte.gz
Extracting /mnist/t10k-labels-idx1-ubyte.gz
x_train: (55000, 784)
y_train: (55000, 10)
x_test: (10000, 784)
y_test: (10000, 10)