訓練指標輸出
1. 使用TensorBoard
2. 使用History類
TensorBorad
TensorBoard的Scalars可以可視化這些指標
使用步驟:
記錄訓練中的指標,需要執行以下操作:
- 創建KerasTensorBoard回調
- 指定日誌目錄
- 將TensorBoard回調傳遞Keras的Model.fit()
回調函數:
tf.keras.callbacks.TensorBoard(
log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
update_freq='epoch', profile_batch=2, embeddings_freq=0,
embeddings_metadata=None, **kwargs
)
參數:
- log_dir:將要由TensorBoard解析的日誌文件保存到的目錄路徑。
- histogram_freq:計算模型各層的激活度和權重直方圖的頻率。如果設置爲0,將不計算直方圖。必須爲直方圖可視化指定驗證數據。
- write_graph:是否在TensorBoard中可視化圖形。當write_graph設置爲True時,日誌文件可能會變得很大。
- write_images:是否編寫模型權重以在TensorBoard中可視化爲圖像。
- update_freq:‘batch’或’epoch’或整數。使用時’batch’,每批之後將損失和指標寫入TensorBoard。同樣適用於’epoch’。如果使用整數,假設1000,回調將每1000批將指標和損失寫入TensorBoard。請注意,過於頻繁地向TensorBoard寫入可能會減慢訓練速度。
- profile_batch:分析批次以採樣計算特徵。默認情況下,它將配置第二批。將profile_batch = 0設置爲禁用分析。必須在TensorFlow急切模式下運行。
- embeddings_freq:嵌入層可視化的頻率(以曆元計)。如果設置爲0,則嵌入將不可見。
- embeddings_metadata:將層名稱映射到文件名的字典,該嵌入層的元數據保存在該文件名中。查看 有關元數據文件格式的 詳細信息。如果相同的元數據文件用於所有嵌入層,則可以傳遞字符串。
定義好回調函數後,在fit()函數中加入參數
如下:
logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
.....(省略掉的代碼)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
callbacks=[tensorboard_callback]
)
然後在終端,使用 tensorboard --logdir log/
,就會出現下面的信息:
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.1.0 at http://localhost:6007/ (Press CTRL+C to quit)
進入連接即可出現網頁,就會顯示圖表。
使用History類
這種方式比較簡單,
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
callbacks=[tensorboard_callback]
fit() 會返回一個History的類,它的History.history屬性記錄了訓練時期(每個epoch),訓練損失和準確率以及驗證損失和驗證準確率。
如下:
model.compile(optimizer=keras.optimizers.Adam(lr=0.01),
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
# callbacks=[tensorboard_callback]
)
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label = 'val_loss')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()