(Click!)各層參數詳解
本程序結構:
版本:TensorFlow 2.1
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers
from tensorflow.keras.layers import Conv2D, MaxPool2D, Dense
import matplotlib.pyplot as plt
# Load the MNIST dataset: (x, y) is the training split, (x_val, y_val) the
# test split used here for validation. Each image is 28x28 (see the reshape
# to [-1, 28, 28, 1] in preprocess below); labels are integer digits.
(x, y), (x_val, y_val) = datasets.mnist.load_data()
# Wrap the raw arrays in tf.data pipelines so they can be batched and
# preprocessed lazily during training.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
def preprocess(x, y):
    """Normalize a batch of MNIST images and one-hot encode the labels.

    :param x: batch of images with trailing spatial shape (28, 28)
    :param y: batch of integer class labels in [0, 9]
    :return: tuple of (float32 images scaled to [0, 1] with shape
             (batch, 28, 28, 1), one-hot float labels of depth 10)
    """
    # Labels: integer-cast, then expand to a 10-way one-hot vector.
    labels = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    # Images: scale pixel values into [0, 1] and add the channel axis
    # expected by Conv2D.
    images = tf.cast(x, dtype=tf.float32) / 255.
    images = tf.reshape(images, [-1, 28, 28, 1])
    return images, labels
# datasets preprocess
# Batch first, then map: preprocess therefore receives whole 100-image
# batches, and its reshape to [-1, 28, 28, 1] keeps the batch dimension.
train_db = train_db.batch(100).map(preprocess)
val_db = val_db.batch(100).map(preprocess)
# Sanity check one batch; expected shapes: (100, 28, 28, 1) and (100, 10).
sample = next(iter(train_db))
print(sample[0].shape, sample[1].shape)
# LeNet-5 model: two tanh conv/pool stages followed by three dense layers,
# assembled as a single Sequential layer list.
model = Sequential([
    # C1: 6 feature maps, 5x5 kernels, valid padding -> (24, 24, 6)
    Conv2D(filters=6, kernel_size=(5, 5), padding='valid',
           activation='tanh', input_shape=(28, 28, 1)),
    # S2: 2x2 max pooling -> (12, 12, 6)
    MaxPool2D(pool_size=(2, 2)),
    # C3: 16 feature maps, 5x5 kernels, valid padding -> (8, 8, 16)
    Conv2D(filters=16, kernel_size=(5, 5), padding='valid',
           activation='tanh'),
    # S4: 2x2 max pooling -> (4, 4, 16)
    MaxPool2D(pool_size=(2, 2)),
    # Flatten the 4x4x16 feature maps into a 256-dim vector
    layers.Flatten(),
    # F5 and F6: fully connected tanh layers
    Dense(units=120, activation='tanh'),
    Dense(units=84, activation='tanh'),
    # Output: 10-way softmax over the digit classes
    Dense(units=10, activation="softmax"),
])
model.summary()
# Train with plain SGD and categorical cross-entropy (labels are one-hot
# encoded by preprocess, so the non-sparse loss is the right choice).
model.compile(optimizer=tf.optimizers.SGD(learning_rate=0.05),
              loss='categorical_crossentropy',
              metrics=["accuracy"])
# NOTE: `shuffle=True` was dropped — Keras ignores it when the input is a
# tf.data.Dataset; shuffling must be applied on the Dataset itself
# (e.g. train_db.shuffle(buffer_size)).
# epochs=6 was found to work best after a few training runs.
history = model.fit(train_db, epochs=6, validation_data=val_db)
# Plot training vs. validation loss per epoch.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('LeNet-5 loss')
plt.ylabel('loss')
plt.xlabel('epoch')
# BUG FIX: the legend previously read ['Train', 'loss'], mislabelling the
# second curve, which is the validation loss; label both curves correctly
# and in plotting order.
plt.legend(['train loss', 'val loss'])
plt.show()
輸出
C:\Users\Administrator\anaconda3\python.exe E:/桌面文件/WorkStation/實戰/LeNet-5.py
(100, 28, 28, 1) (100, 10)
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 24, 24, 6) 156
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 12, 12, 6) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 8, 8, 16) 2416
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 4, 16) 0
_________________________________________________________________
flatten (Flatten) (None, 256) 0
_________________________________________________________________
dense (Dense) (None, 120) 30840
_________________________________________________________________
dense_1 (Dense) (None, 84) 10164
_________________________________________________________________
dense_2 (Dense) (None, 10) 850
=================================================================
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
_________________________________________________________________
Train for 600 steps, validate for 100 steps
Epoch 1/6
1/600 [..............................] - ETA: 13:51 - loss: 2.3123 - accuracy: 0.0600
13/600 [..............................] - ETA: 1:04 - loss: 2.1932 - accuracy: 0.2969
25/600 [>.............................] - ETA: 34s - loss: 1.9919 - accuracy: 0.4480
37/600 [>.............................] - ETA: 23s - loss: 1.7696 - accuracy: 0.5319
49/600 [=>............................] - ETA: 17s - loss: 1.5669 - accuracy: 0.5951
61/600 [==>...........................] - ETA: 14s - loss: 1.4139 - accuracy: 0.6366
73/600 [==>...........................] - ETA: 12s - loss: 1.2929 - accuracy: 0.6695
86/600 [===>..........................] - ETA: 10s - loss: 1.1928 - accuracy: 0.6964
99/600 [===>..........................] - ETA: 9s - loss: 1.1105 - accuracy: 0.7182
111/600 [====>.........................] - ETA: 8s - loss: 1.0394 - accuracy: 0.7366
123/600 [=====>........................] - ETA: 7s - loss: 0.9842 - accuracy: 0.7506
135/600 [=====>........................] - ETA: 6s - loss: 0.9407 - accuracy: 0.7610
......
564/600 [===========================>..] - ETA: 0s - loss: 0.4305 - accuracy: 0.8858
577/600 [===========================>..] - ETA: 0s - loss: 0.4243 - accuracy: 0.8875
589/600 [============================>.] - ETA: 0s - loss: 0.4177 - accuracy: 0.8893
600/600 [==============================] - 4s 7ms/step - loss: 0.4126 - accuracy: 0.8908 - val_loss: 0.1633 - val_accuracy: 0.9515
Epoch 2/6
......
1/600 [..............................] - ETA: 20s - loss: 0.0732 - accuracy: 0.9800
13/600 [..............................] - ETA: 4s - loss: 0.0775 - accuracy: 0.9800
25/600 [>.............................] - ETA: 3s - loss: 0.0618 - accuracy: 0.9844
37/600 [>.............................] - ETA: 2s - loss: 0.0604 - accuracy: 0.9854
49/600 [=>............................] - ETA: 2s - loss: 0.0578 - accuracy: 0.9849
61/600 [==>...........................] - ETA: 2s - loss: 0.0575 - accuracy: 0.9839
73/600 [==>...........................] - ETA: 2s - loss: 0.0594 - accuracy: 0.9842
84/600 [===>..........................] - ETA: 2s - loss: 0.0603 - accuracy: 0.9837
96/600 [===>..........................] - ETA: 2s - loss: 0.0634 - accuracy: 0.9824
108/600 [====>.........................] - ETA: 2s - loss: 0.0624 - accuracy: 0.9824
......
579/600 [===========================>..] - ETA: 0s - loss: 0.0591 - accuracy: 0.9827
590/600 [============================>.] - ETA: 0s - loss: 0.0584 - accuracy: 0.9829
600/600 [==============================] - 3s 5ms/step - loss: 0.0584 - accuracy: 0.9830 - val_loss: 0.0612 - val_accuracy: 0.9801
Process finished with exit code 0