導入相關包
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# tensorflow自帶的一些數據集
from tensorflow.examples.tutorials.mnist import input_data
加載數據集
在該目錄下建立一個空文件夾 data;加載 mnist 數據集時,會自動從網上下載。
# Load MNIST from the local 'data/' directory with one-hot labels, then
# report the loader's return type and the train/test example counts.
print('Download and Extract MNIST dataset')
mnist = input_data.read_data_sets('data/', one_hot=True)

# Each entry pairs a report template with the value substituted into it.
for template, value in (
    ('type of "mnist" is %s', type(mnist)),
    ('number of train data is %d', mnist.train.num_examples),
    ('number of test data is %d', mnist.test.num_examples),
):
    print(template % (value,))
mnist數據集的描述信息
# Pull the image and label arrays out of the train and test splits.
trainimg, trainlabel = mnist.train.images, mnist.train.labels
testimg, testlabel = mnist.test.images, mnist.test.labels

# Report the Python type of every array first, then every array's shape,
# in the order: train images, train labels, test images, test labels.
for template, value in (
    ('type of "trainimg" is %s', type(trainimg)),
    ('type of "trainlabel" is %s', type(trainlabel)),
    ('type of "testimg" is %s', type(testimg)),
    ('type of "testlabel" is %s', type(testlabel)),
    ('shape of "trainimg" is %s', trainimg.shape),
    ('shape of "trainlabel" is %s', trainlabel.shape),
    ('shape of "testimg" is %s', testimg.shape),
    ('shape of "testlabel" is %s', testlabel.shape),
):
    # Wrap in a 1-tuple so a shape tuple is formatted as a single value.
    print(template % (value,))
輸出結果:
打印原數據集
# Show a handful of randomly chosen training samples together with their
# labels, both as a figure title and on stdout.
nsample = 5
# Draw `nsample` row indices uniformly from [0, number of training rows);
# draws are independent, so the same index may appear more than once.
randidx = np.random.randint(trainimg.shape[0], size=nsample)
for i in randidx:
    # Each training row is reshaped into a 28x28 matrix for display.
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    # Position of the largest entry in the label row is used as the class.
    curr_label = np.argmax(trainlabel[i, :])
    # Build the caption once; the trailing space after "Data" keeps the two
    # parts from running together ("...DataLabel is ...").
    caption = str(i) + 'th Training Data ' + 'Label is ' + str(curr_label)
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title(caption)
    print(caption)
# NOTE(review): there is no plt.show() here; figures presumably render only
# under an inline/interactive backend — confirm how this script is run.
這裏只展示一張圖片: