Tensorflow2.x.x最基礎的神經網絡(ANN)

Tensorflow2.x.x最基礎的神經網絡(ANN)

本章節主要使用Tensorflow2.x.x來搭建ANN神經網絡。

ANN原理

這裏直接放上小夥伴ANN的原理博客~

實現

使用ANN實現對MNIST數據集的分類。

import tensorflow as tf
# mnist數據集
from tensorflow.keras.datasets import mnist
# Adam優化器
from tensorflow.keras.optimizers import Adam
# 交叉熵損失函數,一般用於多分類
from tensorflow.keras.losses import CategoricalCrossentropy
# 模型和網絡層
from tensorflow.keras import Model, layers

# 批次大小
BATCH_SIZE = 128
# 迭代次數
EPOCHS = 10
# 加載mnist的訓練、測試數據集
train, test = mnist.load_data()
# 數據集的預處理
@tf.function
def preprocess(x, y):
    # 將x一維數據轉爲3維灰度圖
    x = tf.reshape(x, [28, 28, 1])
    # 將x的範圍由[0, 255]爲[0, 1]
    x = tf.image.convert_image_dtype(x, tf.float32)
    # 將y數字標籤進行獨熱編碼
    y = tf.one_hot(y, 10)
    # 返回處理後的x和y
    return x, y

# 使用Dataset來減少內存的使用
train = tf.data.Dataset.from_tensor_slices(train)
# 對數據進行預處理並且給定BATCH_SIZE
train = train.map(preprocess).batch(BATCH_SIZE)

# test數據集同理
test = tf.data.Dataset.from_tensor_slices(test)
test = test.map(preprocess).batch(BATCH_SIZE)

# 搭建模型(只是其中的一種搭建方式而已)
x = layers.Input(shape=(28, 28, 1))                 # 輸入爲x, 大小爲 28*28*1
y = layers.Flatten()(x)                             # 將高維數據扁平化
y = layers.Dense(1024, activation='relu')(y)        # 輸出1024個神經元的全網絡層
y = layers.Dense(512, activation='relu')(y)         # 輸出512個神經元的全網絡層
y = layers.Dense(256, activation='relu')(y)         # 輸出256個神經元的全網絡層
y = layers.Dense(128, activation='relu')(y)         # 輸出128個神經元的全網絡層
y = layers.Dense(64, activation='relu')(y)          # 輸出64個神經元的全網絡層
y = layers.Dense(32, activation='relu')(y)          # 輸出32個神經元的全網絡層
y = layers.Dense(10, activation='softmax')(y)       # 輸出10個神經元的全網絡層,最後一層使用了softmax進行激活,原因是我們希望提前[0, 1]之間的概率

# 創建模型
ann = Model(x, y)
# 編譯模型,選擇優化器、評估標準、損失函數
ann.compile(optimizer=Adam(), metrics=['acc'], loss=CategoricalCrossentropy())
# 進行模型訓練
history = ann.fit(train, epochs=EPOCHS)
# 測試集的評估
score = ann.evaluate(test)
# 打印評估成績
print('loss: {0}, acc: {1}'.format(score[0], score[1])) # loss: 0.11106619730560828, acc: 0.9769999980926514

# 繪製訓練過程中每個epoch的loss和acc的折線圖
import matplotlib.pyplot as plt
# history對象中有history字典, 字典中存儲着“損失”和“評估標準”
epochs = range(EPOCHS)
fig = plt.figure(figsize=(15, 5), dpi=100)

ax1 = fig.add_subplot(1, 2, 1)
ax1.plot(epochs, history.history['loss'])
ax1.set_title('loss graph')
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss val')

ax2 = fig.add_subplot(1, 2, 2)
ax2.plot(epochs, history.history['acc'])
ax2.set_title('acc graph')
ax2.set_xlabel('epochs')
ax2.set_ylabel('acc val')

fig.show()

輸出結果如下:
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章