Tensorflow2.0之模型權值的保存與恢復(Checkpoint)

介紹

很多時候,我們希望在模型訓練完成後能將訓練好的參數(變量)保存起來。在需要使用模型的其他地方載入模型和參數,就能直接得到訓練好的模型。

TensorFlow 提供了 tf.train.Checkpoint 這一強大的變量保存與恢復類,可以使用其 save() 和 restore() 方法將 TensorFlow 中所有包含 Checkpointable State 的對象進行保存和恢復。具體而言,tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer 或者 tf.keras.Model 實例都可以被保存。

保存變量

# train.py 模型訓練階段

model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型訓練代碼)
# 模型訓練完畢後將參數保存到文件
checkpoint.save('./save/model.ckpt')

這裏 tf.train.Checkpoint() 接受的初始化參數比較特殊,是一個 **kwargs 。具體而言,是一系列的鍵值對,鍵名可以隨意取,值爲需要保存的對象。在這裏,我們取鍵名爲 myModel,指定保存對象爲 model。如果我們希望保存其他對象如 Optimizer 的參數,我們可以這樣寫:

checkpoint = tf.train.Checkpoint(myModel=model, myOptimizer=optimizer)

訓練完後,checkpoint 文件會出現在 ‘./save/’ 文件夾下,‘model.ckpt’ 是這些文件的前綴。如果我們只調用了一次 checkpoint.save 函數,那麼在 ‘./save/’ 文件夾下會出現名爲 checkpoint 、 model.ckpt-1.index 、 model.ckpt-1.data-00000-of-00001 的三個文件,這些文件就記錄了變量信息。checkpoint.save() 方法可以運行多次,每運行一次都會得到一個.index 文件和.data 文件,序號依次累加。

恢復變量

當在其他地方需要爲模型重新載入之前保存的參數時,需要再次實例化一個 checkpoint,同時保持鍵名的一致。再調用 checkpoint 的 restore 方法。

# test.py 模型使用階段

model_to_be_restored = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model_to_be_restored)  # 實例化Checkpoint,指定恢復對象爲model
checkpoint.restore(tf.train.latest_checkpoint('./save'))  # 從文件恢復模型參數

當保存了多個文件時,我們往往想載入最近的一個。可以使用 tf.train.latest_checkpoint(save_path) 這個輔助函數返回目錄下最近一次 checkpoint 的文件名。例如如果 save 目錄下有 model.ckpt-1.index 到 model.ckpt-10.index 的 10 個保存文件, tf.train.latest_checkpoint(’./save’) 即返回 ./save/model.ckpt-10 。

有限制地保留 Checkpoint 文件

在模型的訓練過程中,我們往往每隔一定步數保存一個 Checkpoint 並進行編號。不過很多時候我們會有這樣的需求:

  • 在長時間的訓練後,程序會保存大量的 Checkpoint,但我們只想保留最後的幾個 Checkpoint;

  • Checkpoint 默認從 1 開始編號,每次累加 1,但我們可能希望使用別的編號方式(例如使用當前 epoch 的編號作爲文件編號)。

這時,我們可以使用 TensorFlow 的 tf.train.CheckpointManager 來實現以上需求。具體而言,在定義 Checkpoint 後接着定義一個 CheckpointManager:

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)

在需要保存模型的時候,我們直接使用 manager.save() 即可。如果我們希望自行指定保存的 Checkpoint 的編號,則可以在保存時加入 checkpoint_number 參數。例如 manager.save(checkpoint_number=100) 。

實例

我們通過對 MNIST 數據集的訓練來舉例:

1、定義模型及訓練過程

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

mnist = keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis].astype(np.float32)
x_test = x_test[..., tf.newaxis].astype(np.float32)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(x_test.shape[0])

class MyModel(keras.Model):
    # Set layers.
    def __init__(self):
        super(MyModel, self).__init__()
        # Convolution Layer with 32 filters and a kernel size of 5.
        self.conv1 = layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)
        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
        self.maxpool1 = layers.MaxPool2D(2, strides=2)

        # Convolution Layer with 64 filters and a kernel size of 3.
        self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)
        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
        self.maxpool2 = layers.MaxPool2D(2, strides=2)

        # Flatten the data to a 1-D vector for the fully connected layer.
        self.flatten = layers.Flatten()

        # Fully connected layer.
        self.fc1 = layers.Dense(1024)
        # Apply Dropout (if is_training is False, dropout is not applied).
        self.dropout = layers.Dropout(rate=0.5)

        # Output layer, class prediction.
        self.out = layers.Dense(10)

    # Set forward pass.
    def call(self, x, is_training=False):
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.dropout(x, training=is_training)
        x = self.out(x)
        if not is_training:
            # tf cross entropy expect logits without softmax, so only
            # apply softmax when not training.
            x = tf.nn.softmax(x)
        return x

model = MyModel()

loss_object = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam()

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

2、保存模型參數

2.1 不限制 checkpoint 文件個數

EPOCHS = 5

checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    path = checkpoint.save('./save/model.ckpt')
    print("model saved to %s" % path)

2.2 限制 checkpoint 文件個數

EPOCHS = 5

checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    path = manager.save(checkpoint_number=epoch)
    print("model saved to %s" % path)

3、加載模型參數

model_to_be_restored = MyModel()
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      
checkpoint.restore(tf.train.latest_checkpoint('./save')) 
for test_images, test_labels in test_ds:
    y_pred = np.argmax(model_to_be_restored.predict(test_images), axis=-1)
    print("test accuracy: %f" % (sum(tf.cast(y_pred == test_labels, tf.float32)) / x_test.shape[0]))
test accuracy: 0.989600

參考資料

簡單粗暴 TensorFlow 2

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章