還是入坑了的TF2_1——實現全連接手寫體識別

學習前言

還是要開始學習tf2呀,看看有沒有意思!這玩意好像就是Keras呀,老熟人了,感覺很快樂!
在這裏插入圖片描述

重要函數

1、Model

Model用於建立模型。與Keras一樣,可以傳入Inputs和Outputs作爲輸入輸出。很簡單就可以構建一個模型。
使用方法如下:

# 建立模型
model = Model(inputs,out)

2、Input

Input用於建立輸入量。與Keras一樣,需要指定輸入進來的內容的shape,可以是圖片也可以是一維向量之類的。
使用方法如下:

# 作爲輸入
inputs = Input([28,28])

3、Dense

Dense用於往model中添加全連接層。全連接層示意圖如下。
在這裏插入圖片描述
具體而言,簡單的BP神經網絡中,輸入層到隱含層中間的權值連接,其實與全連接層的意義相同。
與Keras一樣,需要指定全連接的神經元數量,還可以指定激活函數。

x = Flatten(input_shape=(28, 28))(inputs)
x = Dense(128, activation='relu')(x)
x = Dropout(0.2)(x)
out = Dense(10, activation='softmax')(x)

4、model.compile

model.compile主要用於定義loss函數和優化器。
其調用方式如下:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

其中loss用於定義計算損失的損失函數,其可以選擇的內容如下:
1、mse:均方根誤差,常用於迴歸預測。
2、categorical_crossentropy:亦稱作多類的對數損失,注意使用該目標函數時,需要將標籤轉化爲形如(nb_samples, nb_classes)的二值序列,常用於分類。
3、sparse_categorical_crossentrop:如上,但接受稀疏標籤。

metrics=[‘accuracy’]常用於分類運算中,accuracy代表計算分類精確度。

5、model.fit

用於接收訓練數據用於訓練:

# 利用fit進行訓練
model.fit(x_train, y_train, epochs=5)

全部代碼

import tensorflow as tf
from tensorflow.keras.layers import Flatten,Dense,Dropout,Input
from tensorflow.keras.models import Model
print(tf.__version__)
print(tf.keras.__version__)
# 載入Mnist手寫數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 作爲輸入
inputs = Input([28,28])
x = Flatten(input_shape=(28, 28))(inputs)
x = Dense(128, activation='relu')(x)
x = Dropout(0.2)(x)
out = Dense(10, activation='softmax')(x)

# 建立模型
model = Model(inputs,out)

# 設定優化器,loss,計算準確率
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 利用fit進行訓練
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test,  y_test, verbose=2)

輸出如下:

2.0.0
2.2.4-tf
Epoch 1/5
60000/60000 [==============================] - 3s 45us/sample - loss: 0.2989 - accuracy: 0.9121
Epoch 2/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.1415 - accuracy: 0.9577
Epoch 3/5
60000/60000 [==============================] - 2s 38us/sample - loss: 0.1067 - accuracy: 0.9674
Epoch 4/5
60000/60000 [==============================] - 2s 35us/sample - loss: 0.0893 - accuracy: 0.9726
Epoch 5/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.0746 - accuracy: 0.9769
10000/1 - 0s - loss: 0.0427 - accuracy: 0.9756
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章