# 【小白學PyTorch】19 TF2模型的存儲與載入

【新聞】：機器學習煉丹術的粉絲的人工智能交流羣已經建立，目前有目標檢測、醫學圖像、時間序列等多個目標爲技術學習的分羣和水羣嘮嗑的總羣，歡迎大家加煉丹兄爲好友，加入煉丹協會。微信：cyx645016617.

## 1 模型的構建

``````import tensorflow.keras as keras

class CBR(keras.layers.Layer):
def __init__(self,output_dim):
super(CBR,self).__init__()
self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
self.bn = keras.layers.BatchNormalization(axis=3)
self.ReLU = keras.layers.ReLU()

def call(self, inputs):
inputs = self.conv(inputs)
inputs = self.ReLU(self.bn(inputs))
return inputs

class MyNet(keras.Model):
def __init__ (self):
super(MyNet,self).__init__()
self.cbr1 = CBR(16)
self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
self.cbr2 = CBR(32)
self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2))

def call(self, inputs):
inputs = self.maxpool1(self.cbr1(inputs))
inputs = self.maxpool2(self.cbr2(inputs))
return inputs

model = MyNet()
``````

``````model.build((16,224,224,3))
print(model.summary())
``````

``````model.build((16,224,224,1))
print(model.summary())
``````

## 2 結構參數的存儲與載入

``````model.save('save_model.h5')
``````

## 3 參數的存儲與載入

``````model.save_weights('model_weight')
new_model = MyNet()
``````

``````# 看一下原來的模型和載入的模型預測相同的樣本的輸出
test = tf.ones((1,8,8,3))
prediction = model.predict(test)
new_prediction = new_model.predict(test)
print(prediction,new_prediction)
>>> [[[[0.02559286]]]] [[[[0.02559286]]]]
``````

## 4 結構的存儲與載入

• `model.get_config()`
• `model.to_json()`

``````# 第一種方法
config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
# 第二種方法
json_config = model.to_json()
# 把json寫的文件中
with open('model_config.json', 'w') as json_file:
json_file.write(json_config)
# 讀取本地json文件
with open('model_config.json') as json_file: