tensorflow2.0的模型保存加載的幾個方法

tensorflow2.0中模型的加載更加便捷。
我在github上新建了一個有關ner的項目,其中有對tensorflow2.0的api的一些詳細使用。NER
想了解更多tensorflow2.0中模型存儲加載方法,可以直接到其官方網站tf2.0.

我們這裏說一下幾個保存權重的方法:
假如當前建立的模型代碼如下:

import tensorflow as tf
from tensorflow import keras
def get_model():
  # Create a simple model.
  inputs = keras.Input(shape=(32,))
  outputs = keras.layers.Dense(1)(inputs)
  model = keras.Model(inputs, outputs)
  model.compile(optimizer='adam', loss='mean_squared_error')
  return model

model = get_model()
opt = tf.keras.optimizers.Adam(0.1)
checkpoint_dir="./checkpoint"

1、保存檢查點
具體api如下:

tf.train.Checkpoint
tf.train.CheckpointManager

使用以上兩個api就可以保存訓練中所有的權重。具體操作如下:
首先創建檢查點

ckpt = tf.train.Checkpoint(optimizer=opt,model=model)
manager = tf.train.CheckpointManager(ckpt,
                                   checkpoint_dir,
                                   max_to_keep=3)

具體參數含義可以直接help查看api中參數的解釋。
創建完檢查點後,如果存在舊模型,就需要從舊模型中恢復權重。操作如下:

ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
     print("Restored from {}".format(manager.latest_checkpoint))
else:
     print("Initializing from scratch.")

然後再看看如何保存模型
在訓練過程中我們可以直接使用manager的功能save進行存儲,相關代碼如下:

for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

2、keras內置存儲功能
keras有兩種可以存儲載入模型的內置函數

  • save 與load_model函數對
    具體api:
model.save() 或者 tf.keras.models.save_model()
tf.keras.models.load_model()

保存模型:

model.save(checkpoint_dir)

默認存儲結果會有三個:

assets  saved_model.pb  variables

也可以直接指定存儲爲HDF5的格式

model.save(checkpoint_dir + '/' + 'model.h5')

載入模型:

recover_model = keras.models.load_model(checkpoint_dir)

如果是HDF5格式文件:

recover_model = keras.models.load_model(checkpoint_dir+ '/' + 'model.h5')

載入模型後生成新的對象recover_model,會複製原來model的所有功能。後續的訓練測試使用recover_model

  • keras.Model 內置的save_weights與load_weights、get_weights與set_weights。
    其中常用的是save_weights與load_weights。
    save_weights可以有兩種存儲方式,tensorflow格式與h5格式。默認爲使用tensorflow方式也是類似於檢查點的方式進行存儲。
    具體操作如下:
model.save_weights(path=checkpoint_dir)

存儲爲h5

model.save_weights(path=checkpoint_dir+/model.h5’,save_format='h5')

載入方式:
非h5文件:

model.load_weights(path=checkpoint_dir)

h5文件的載入:

model.load_weights(path=checkpoint_dir+'/model.h5')

以上就是訓練過程常用的模型存儲與加載的方式。可以看到tf2.0簡化了tf1.0中的許多操作,對於用戶來說已經是非常友好。擁抱pytorch的同學們可以再回來繼續當tfboys。但是模型的部署怎麼搞?以上幾種辦法,都需要搭建原始的結構,然後載入權重。這和環境上部署毛都不沾。
接下來介紹一個保存環境部署的模型的方法。具體應用在我的ner項目中已經體現,具體文件是run_pb.py。只有簡單的四行代碼,就可以載入模型。

首先看看怎麼保存模型:
可以先恢復檢查點,但是忽略優化器之類的權重。這裏參考我ner項目中的infer代碼。可用重新建一個圖,不包含任何優化器節點。
然後恢復模型:

ner = ner_model(config,training=False)
#從訓練的檢查點恢復權重
ckpt = tf.train.Checkpoint(ner=ner)
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir + 'trains')
#添加expect_partial()忽略優化器相關節點
status = ckpt.restore(latest_ckpt).expect_partial()

恢復之後,保存模型:

tf.saved_model.save(ner, checkpoint_dir + 'infers/')

經過這一步,我們可以看到,在checkpoint_dir + 'infers/'目錄下有:

assets  saved_model.pb  variables

一個pb文件,兩個檢查點目錄,這兩個目錄裏面東西的作用,目前未知。

然後參看run_pb.py中代碼。僅僅只有四行代碼,我們就可以部署訓練號的模型到環境上。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章