(三)Tensorflow學習——mnist數據集簡介

導入相關包

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# tensorflow自帶的一些數據集
from tensorflow.examples.tutorials.mnist import input_data

加載數據集

在該目錄下,建立一個空文件夾data,加載mnist數據集時,會自動從網上下載

print('Download and Extract MNIST dataset')
mnist = input_data.read_data_sets('data/', one_hot=True)
print('type of "mnist" is %s' % (type(mnist)))
print('number of train data is %d' % (mnist.train.num_examples))
print('number of test data is %d' % (mnist.test.num_examples))

在這裏插入圖片描述
在這裏插入圖片描述

mnist數據集的描述信息

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('type of "trainimg" is %s' % (type(trainimg)))
print('type of "trainlabel" is %s' % (type(trainlabel)))
print('type of "testimg" is %s' % (type(testimg)))
print('type of "testlabel" is %s' % (type(testlabel)))
print('shape of "trainimg" is %s' % (trainimg.shape,))
print('shape of "trainlabel" is %s' % (trainlabel.shape,))
print('shape of "testimg" is %s' % (testimg.shape,))
print('shape of "testlabel" is %s' % (testlabel.shape,))

輸出結果:
在這裏插入圖片描述

打印原數據集

nsample = 5
randidx = np.random.randint(trainimg.shape[0], size=nsample)

for i in randidx:
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title('' + str(i) + 'th Training Data'
              + 'Label is ' + str(curr_label))
    print('' + str(i) + 'th Training Data'
              + 'Label is ' + str(curr_label))

這裏只展示一張圖片:
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章