keras 使用時的一些注意事項

  1. fit 時顯示的loss
    訓練時顯示的loss和acc,是已經運行過的batch的平均loss。
    https://github.com/keras-team/keras/issues/10426
  1. 保存每個epoch的model
    filepath = 'saved-model-{epoch:02d}.h5'
    checkpointer = ModelCheckpoint(filepath=filepath, verbose=1, save_best_only=False, save_weights_only=False)

如果filepath是一個固定的字符串,那隻會保存最新的model文件。

  1. 顯存自適應,而不是一下佔滿
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
config.log_device_placement = False  # to log device placement (on which device the operation ran)
                                    # (nothing gets printed in Jupyter, only if you run it standalone)
sess = tf.Session(config=config)
set_session(sess) 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章