Tensorflow2.x 訓練網絡時的指標輸出,以及模型結構圖導出

訓練指標輸出

1. 使用TensorBoard
2. 使用History類

TensorBorad

TensorBoard的Scalars可以可視化這些指標

使用步驟:

記錄訓練中的指標,需要執行以下操作:

  1. 創建KerasTensorBoard回調
  2. 指定日誌目錄
  3. 將TensorBoard回調傳遞Keras的Model.fit()

回調函數:

tf.keras.callbacks.TensorBoard(
           log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
          update_freq='epoch', profile_batch=2, embeddings_freq=0,
    embeddings_metadata=None, **kwargs
)

參數

  1. log_dir:將要由TensorBoard解析的日誌文件保存到的目錄路徑。
  2. histogram_freq:計算模型各層的激活度和權重直方圖的頻率。如果設置爲0,將不計算直方圖。必須爲直方圖可視化指定驗證數據。
  3. write_graph:是否在TensorBoard中可視化圖形。當write_graph設置爲True時,日誌文件可能會變得很大。
  4. write_images:是否編寫模型權重以在TensorBoard中可視化爲圖像。
  5. update_freq:‘batch’或’epoch’或整數。使用時’batch’,每批之後將損失和指標寫入TensorBoard。同樣適用於’epoch’。如果使用整數,假設1000,回調將每1000批將指標和損失寫入TensorBoard。請注意,過於頻繁地向TensorBoard寫入可能會減慢訓練速度。
  6. profile_batch:分析批次以採樣計算特徵。默認情況下,它將配置第二批。將profile_batch = 0設置爲禁用分析。必須在TensorFlow急切模式下運行。
  7. embeddings_freq:嵌入層可視化的頻率(以曆元計)。如果設置爲0,則嵌入將不可見。
  8. embeddings_metadata:將層名稱映射到文件名的字典,該嵌入層的元數據保存在該文件名中。查看 有關元數據文件格式的 詳細信息。如果相同的元數據文件用於所有嵌入層,則可以傳遞字符串。

定義好回調函數後,在fit()函數中加入參數
如下:

 logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
 tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
.....(省略掉的代碼)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
             callbacks=[tensorboard_callback]
  )

然後在終端,使用 tensorboard --logdir log/,就會出現下面的信息:

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.1.0 at http://localhost:6007/ (Press CTRL+C to quit)

進入連接即可出現網頁,就會顯示圖表。
模型結構圖
在這裏插入圖片描述

使用History類

這種方式比較簡單,

history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
             callbacks=[tensorboard_callback]

fit() 會返回一個History的類,它的History.history屬性記錄了訓練時期(每個epoch),訓練損失和準確率以及驗證損失和驗證準確率。
如下:

model.compile(optimizer=keras.optimizers.Adam(lr=0.01),
        loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
                    # callbacks=[tensorboard_callback]
                    )
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label = 'val_loss')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章