tensorflow2解析模型參數

tensorflow2 解析模型參數

introduction

CNN模型是計算機視覺裏面常用的工具了,訓模師訓好模型可能還需要其他的操作,比如可能做剪枝,或者量化,需要對模型的參數做一些操作。這時候就需要解析模型的參數了。這篇文章主要敘述一下,在tensorflow2 下怎麼解析模型參數。

step by step

1. 先構建並保存一個模型

tensorflow2 構建模型還是首選keras接口, 在模型保存方面有好幾個接口可以選擇,觸類旁通, 此處我習慣使用tf.save_model。

class PlainCNN(tf.keras.Model):
    def __init__(self,
                 kernel_initializer='glorot_normal'):
        super(PlainCNN, self).__init__()


        self.conv1= tf.keras.layers.Conv2D( filters=32,
                                           kernel_size=(3, 3),
                                           strides=2,
                                           padding='same',
                                           use_bias=False,
                                           kernel_initializer=kernel_initializer)

        self.conv2 = tf.keras.layers.Conv2D(filters=64,
                                            kernel_size=(3, 3),
                                            strides=2,
                                            padding='same',
                                            use_bias=False,
                                            kernel_initializer=kernel_initializer)

        self.conv3 = tf.keras.layers.Conv2D(filters=128,
                                            kernel_size=(3, 3),
                                            strides=2,
                                            padding='same',
                                            use_bias=False,
                                            kernel_initializer=kernel_initializer)
    @tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
    def call(self, inputs, training=False):
        x1 = self.conv1(inputs)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        return x3
import numpy as np
model = PlainCNN()
image=np.zeros(shape=(1,48,48,3),dtype=np.float32)
x=model(image)

# 將會打印出模型的信息:
# the result shape is : (1, 6, 6, 128)
print('the result shape is :', x.shape)
##保存模型爲 tmp_model
tf.saved_model.save(model,'tmp_model')

此時我們有了tmp_model 模型目錄

2. 加載模型
model=tf.saved_model.load('./tmp_model')
###這時model可以理解爲上面的代碼,可以直接inference, 也可以依次獲取每個變量的variables,####或者trainable_variables,等等
print(model.conv1.variables)

輸出實例
‘ListWrapper([<tf.Variable ‘conv2d/kernel:0’ shape=(3, 3, 3, 32) dtype=float32, numpy=
array([[[[ 8.40592012e-02, -1.82091370e-02, 6.67047426e-02,
-1.57765336e-02, 5.53220101e-02, 6.69927672e-02,
1.18893180e-02, -3.55695821e-02, 2.96269413e-02,
6.24449886e-02, 8.93794820e-02, 1.48759335e-01,
-6.80063153e-03, -2.44185757e-02, -1.68685019e-01,


然後就可以想做什麼就做什麼了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章