單隱藏層神經網絡(mnist手寫數字識別)

# 1、載入數據
import numpy as np
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
# 讀取mnist數據
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 2.建立模型

# 2.1 構建輸入層
x = tf.placeholder(tf.float32, [None, 784], name='X')
y = tf.placeholder(tf.float32, [None, 10], name='Y')

# 2.2 構建隱藏層
# 隱藏層神經元數量(隨意設置)
H1_NN = 256
# 權重
W1 = tf.Variable(tf.random_normal([784, H1_NN]))
# 偏置項
b1 = tf.Variable(tf.zeros([H1_NN]))

Y1 = tf.nn.relu(tf.matmul(x, W1) + b1)

# 2.3 構建輸出層
W2 = tf.Variable(tf.random_normal([H1_NN, 10]))
b2 = tf.Variable(tf.zeros([10]))

forward = tf.matmul(Y1, W2) + b2
pred = tf.nn.softmax(forward)

# 3.訓練模型

# 3.1 定義損失函數
# tensorflow提供了下面的函數,用於避免log(0)值爲Nan造成數據不穩定
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward, labels=y))
# # 交叉熵損失函數
# loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

# 3.2 設置訓練參數
train_epochs = 40 #訓練輪數
batch_size = 50 #單次訓練樣本數(批次大小)
# 一輪訓練的批次數
total_batch = int(mnist.train.num_examples/batch_size)
display_step = 1 #顯示粒數
learning_rate = 0.01 #學習率

# 3.2 選擇優化器
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)

# 3.3定義準確率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 3.4 模型的訓練
# 記錄訓練開始的時間
from time import time
startTime = time()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(train_epochs):
      for batch in range(total_batch):
            # 讀取批次訓練數據
            xs, ys = mnist.train.next_batch(batch_size)
            # 執行批次訓練
            sess.run(optimizer,feed_dict={x:xs, y:ys})
      # 在total_batch批次數據訓練完成後,使用驗證數據計算誤差和準確率,驗證集不分批
      loss, acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
      # 打印訓練過程中的詳細信息
      if (epoch+1) % display_step == 0:
            print('訓練輪次:','%02d' % (epoch+1),
                  '損失:','{:.9f}'.format(loss),
                  '準確率:','{:.4f}'.format(acc))
print('訓練結束')
# 顯示總運行時間
duration = time() - startTime
print("總運行時間爲:", "{:.2f}".format(duration))

# 4.評估模型
accu_test = sess.run(accuracy,
                     feed_dict={x:mnist.test.images, y:mnist.test.labels})
print('測試集準確率:', accu_test)

# 5.應用模型
prediction_result = sess.run(tf.argmax(pred, 1), feed_dict={x:mnist.test.images})
# 查看預測結果的前10項
print("前10項的結果:", prediction_result[0:10])

# 5.1找出預測錯誤的樣本
compare_lists = prediction_result==np.argmax(mnist.test.labels,1)
print(compare_lists)
err_lists = [i for i in range(len(compare_lists)) if compare_lists[i]==False]
print('預測錯誤的圖片:', err_lists)
print('預測錯誤圖片的總數:', len(err_lists))

# 定義一個輸出錯誤分類的函數
import numpy as np
def print_predict_errs(labels,#標籤列表
                       prediction):#預測值列表
    count = 0
    compare_lists = (prediction == np.argmax(labels, 1))
    err_lists = [i for i in range(len(compare_lists)) if compare_lists[i] == False]
    for x in err_lists:
        print('index='+str(x)+'標籤值=',np.argmax(labels[x]), '預測值=', prediction[x])
        count = count + 1
    print("總計:"+str(count))
print_predict_errs(labels=mnist.test.labels, prediction=prediction_result)

# 可視化
import matplotlib.pyplot as plt


def plot_images_labels_prediction(images,      #圖像列表
                                   labels,      #標籤列表
                                   predication, #預測值列表
                                   index,       #從第index個開始顯示
                                   num=10):    # 缺省一次顯示10幅
      fig = plt.gcf()               #獲取當前圖表,get current figure
      fig.set_size_inches(10,12)    #設爲英寸,1英寸=2.53釐米
      if num > 25:
            num = 25          #最多顯示25個子圖
      for i in range(0,num):
            ax = plt.subplot(5,5, i+1)    #獲取當前要處理的子圖
            # 顯示第index圖像
            ax.imshow(np.reshape(images[index],(28,28)),cmap='binary')

            # 構建該圖上顯示的title
            title = 'label=' + str(np.argmax(labels[index]))
            if len(predication) > 0:
                  title += ",predict=" + str(predication[index])

            # 顯示圖上的title信息
            ax.set_title(title,fontsize=10)
            ax.set_xticks([]) # 不顯示座標軸
            ax.set_yticks([])
            index += 1

      plt.show()

plot_images_labels_prediction(mnist.test.images,
                               mnist.test.labels,
                               prediction_result, 10,25)
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,610, 20)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章