一、瞭解MNIST數據集
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt
import numpy as np
#0 讀取mnist數據集
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
#1 讀取訓練集的一張圖片與標籤
batch_xs, batch_ys = mnist.train.next_batch(1)
img = tf.reshape(batch_xs,[28,28])
label = tf.argmax(batch_ys,1)
#2 讀取測試集的一張圖片與標籤
img_tt = mnist.test.images[0]
img_t = tf.reshape(img_tt,[28,28])
label_tt = mnist.test.labels[0]
label_tt = np.reshape(label_tt,[1,10])
label_t = tf.argmax(label_tt,1)
with tf.Session() as sess:
# 訓練集 打印標籤 像素值
print("訓練集:%d"%sess.run(label))
im = sess.run(img)
for i in range(28):
for j in range(28):
print("%d "%round(im[i][j]),end='')
print()
#測試集 打印標籤 像素值
print("測試集:%d"%sess.run(label_t))
im_t = sess.run(img_t)
for i in range(28):
for j in range(28):
print("%d "%round(im_t[i][j]),end='')
print()
#顯示圖片
plt.figure()
plt.subplot(121)
plt.imshow(im,cmap = 'gray')
plt.subplot(122)
plt.imshow(im_t,cmap = 'gray')
plt.show()
輸出結果如下:
二、 簡單的訓練、測試和保存模型
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
#1 讀取訓練數據
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
#2 建立模型
x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y=tf.nn.softmax(tf.matmul(x,w)+b)
#3 損失函數 交叉熵
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#4 優化訓練
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#5 開啓會話
with tf.Session() as sess:
# 初始化
sess.run(tf.global_variables_initializer())
# 訓練
for i in range(500):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(optimizer, feed_dict={x: batch_xs, y_: batch_ys})
# 測試模型
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) * 100
print(accuracy)
# 保存模型
model_path = "model/"
model_name = "model_" + (str(accuracy))[:4] + "%"
tf.train.Saver().save(sess,model_path+model_name)
三、簡單的使用模型
import tensorflow as tf
#1 模型參數
x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y=tf.nn.softmax(tf.matmul(x,w)+b)
#y_ = tf.placeholder("float", [None,10])
#2 讀取圖片
img = tf.read_file('7.jpg') #讀取 彩色圖片
im_3 = tf.image.decode_jpeg(img, channels=3) #解碼
im_resize = tf.image.resize_images(im_3,[28,28]) #縮放成28X28
im_gry = tf.squeeze(tf.image.rgb_to_grayscale(im_resize),2) #灰度化 降維變成二維
im_reshape = tf.reshape(im_gry,[1,784]) #改變形狀
#3 開啓會話
with tf.Session() as sess:
# 讀取模型
tf.train.Saver().restore(sess,"model/model_91.4%")
# 輸入待檢測圖像數據
xx = sess.run(im_reshape)
# 進行識別並輸出
result = sess.run(tf.argmax(y, 1), feed_dict={x: xx})
print(result)
結果: 結果: