tensorflow2 解析模型參數
introduction
CNN模型是計算機視覺裏面常用的工具了,訓模師訓好模型可能還需要其他的操作,比如可能做剪枝,或者量化,需要對模型的參數做一些操作。這時候就需要解析模型的參數了。這篇文章主要敘述一下,在tensorflow2 下怎麼解析模型參數。
step by step
1. 先構建並保存一個模型
tensorflow2 構建模型還是首選keras接口, 在模型保存方面有好幾個接口可以選擇,觸類旁通, 此處我習慣使用tf.save_model。
class PlainCNN(tf.keras.Model):
def __init__(self,
kernel_initializer='glorot_normal'):
super(PlainCNN, self).__init__()
self.conv1= tf.keras.layers.Conv2D( filters=32,
kernel_size=(3, 3),
strides=2,
padding='same',
use_bias=False,
kernel_initializer=kernel_initializer)
self.conv2 = tf.keras.layers.Conv2D(filters=64,
kernel_size=(3, 3),
strides=2,
padding='same',
use_bias=False,
kernel_initializer=kernel_initializer)
self.conv3 = tf.keras.layers.Conv2D(filters=128,
kernel_size=(3, 3),
strides=2,
padding='same',
use_bias=False,
kernel_initializer=kernel_initializer)
@tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
def call(self, inputs, training=False):
x1 = self.conv1(inputs)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
return x3
import numpy as np
model = PlainCNN()
image=np.zeros(shape=(1,48,48,3),dtype=np.float32)
x=model(image)
# 將會打印出模型的信息:
# the result shape is : (1, 6, 6, 128)
print('the result shape is :', x.shape)
##保存模型爲 tmp_model
tf.saved_model.save(model,'tmp_model')
此時我們有了tmp_model 模型目錄
2. 加載模型
model=tf.saved_model.load('./tmp_model')
###這時model可以理解爲上面的代碼,可以直接inference, 也可以依次獲取每個變量的variables,####或者trainable_variables,等等
print(model.conv1.variables)
輸出實例
‘ListWrapper([<tf.Variable ‘conv2d/kernel:0’ shape=(3, 3, 3, 32) dtype=float32, numpy=
array([[[[ 8.40592012e-02, -1.82091370e-02, 6.67047426e-02,
-1.57765336e-02, 5.53220101e-02, 6.69927672e-02,
1.18893180e-02, -3.55695821e-02, 2.96269413e-02,
6.24449886e-02, 8.93794820e-02, 1.48759335e-01,
-6.80063153e-03, -2.44185757e-02, -1.68685019e-01,
…
…
’
然後就可以想做什麼就做什麼了。