TensorFlow Model Parameter Saving Methods
Saving TensorFlow model parameters falls into two cases: unpruned parameters and pruned parameters. Unpruned parameters can be written directly to a TXT file; pruned parameters can be stored using a sparse-matrix format.
1. Saving Unpruned Parameters
Before the parameters can be read, the model and its variables must first be saved. A save method looks like this:
```python
def save_model(self):
    saver = tf.train.Saver()
    save_path = saver.save(self.sess, "save/model.ckpt")
```
After this runs, the model files are generated under save/, and the parameters can be read back from the checkpoint files.
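With tf.train.Saver's default (V2) checkpoint format, the save/ directory will typically contain something like the following (file names are illustrative; the number of .data shards depends on the save):

```
checkpoint
model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta
```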
Saving the parameters only requires reading the variable names and values from the model. Example code:
```python
import tensorflow as tf

model_dir = "save/"
ckpt = tf.train.get_checkpoint_state(model_dir)
ckpt_path = ckpt.model_checkpoint_path
# create a reader for the checkpoint
reader = tf.train.NewCheckpointReader(ckpt_path)
all_variables = reader.get_variable_to_shape_map()
```
The lines after the import correspond to:
1. the model directory
2. reading the checkpoint state from that directory
3. getting the checkpoint file path
4. creating a reader for the checkpoint
5. getting the dictionary of all variable names and their shapes
Once the variable map is available, get_tensor returns the actual tensor data:
```python
parameter_data = reader.get_tensor(key)
```
Writing to TXT has to be done value by value: the arrays are large, and writing str(parameter_data) directly would produce output full of ellipses. The concrete loop is:
```python
pf = open('parameter.txt', 'w')
for key in all_variables.keys():
    parameter_data = reader.get_tensor(key)
    print('**************** save', key, 'succeed ******************* shape:', parameter_data.shape)
    data_shape = parameter_data.shape
    # header line: variable name and shape
    pf.write(str(key))
    pf.write(',data shape:')
    pf.write(str(all_variables[key]))
    pf.write('\n')
    if len(data_shape) == 0:
        # scalar value
        pf.write(str(parameter_data))
        pf.write('\n')
    # save 1-D data format
    if len(data_shape) == 1:
        pf.write('{')
        for i in range(data_shape[0]):
            pf.write(str(parameter_data[i]))
            if i < data_shape[0] - 1:
                pf.write(',')
        pf.write('}')
        pf.write('\n')
    # save 2-D data format
    if len(data_shape) == 2:
        pf.write('{')
        for i in range(data_shape[0]):
            for j in range(data_shape[1]):
                pf.write(str(parameter_data[i][j]))
                if not (i == data_shape[0] - 1 and j == data_shape[1] - 1):
                    pf.write(',')
        pf.write('}')
        pf.write('\n')
    # save 4-D data format
    if len(data_shape) == 4:
        pf.write('{')
        last = (data_shape[0] - 1, data_shape[1] - 1, data_shape[2] - 1, data_shape[3] - 1)
        for i in range(data_shape[0]):
            for j in range(data_shape[1]):
                for k in range(data_shape[2]):
                    for l in range(data_shape[3]):
                        pf.write(str(parameter_data[i][j][k][l]))
                        if (i, j, k, l) != last:
                            pf.write(',')
        pf.write('}')
        pf.write('\n')
pf.close()
```
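The truncation that motivates the element-by-element writes can be seen with numpy's default printing (a minimal sketch; any array larger than numpy's print threshold behaves this way):

```python
import numpy as np

# A large array converted with str() is summarized with ellipses,
# so the individual values would be lost in the TXT file.
big = np.arange(10000, dtype=np.float32).reshape(100, 100)
print('...' in str(big))  # True
```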
The idea is to make porting to C as a #define convenient: following C initializer syntax, the data is wrapped in { }, with the values written one by one and separated by commas. The header written for each parameter can of course also be changed:
```python
pf.write(str(key))
pf.write(',data shape:')
pf.write(str(all_variables[key]))
pf.write('\n')
```
These lines can be rewritten so each entry is emitted in a directly usable form, such as #define PARAMETER_STR DATA.
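As a sketch of that change, a small helper could produce one complete #define line per parameter (format_define is a hypothetical name, not part of the original code; the macro-name sanitization is one possible choice):

```python
import numpy as np

def format_define(name, data):
    # hypothetical helper: emit one parameter as a single
    # C-style "#define NAME {v0,v1,...}" line
    flat = np.asarray(data).flatten()
    body = ','.join(str(v) for v in flat)
    # TF variable names like "fc1/weights:0" are not valid C macro names,
    # so replace the offending characters
    macro = name.upper().replace('/', '_').replace(':', '_')
    return '#define %s {%s}' % (macro, body)

print(format_define('fc1/weights', np.array([[1, 2], [3, 4]])))
# -> #define FC1_WEIGHTS {1,2,3,4}
```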
The complete code is as follows:
```python
import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


def view_parameter(model_dir):
    ckpt = tf.train.get_checkpoint_state(model_dir)
    ckpt_path = ckpt.model_checkpoint_path
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    all_variables = reader.get_variable_to_shape_map()
    print(all_variables)
    # view variable names, shapes and data
    for key, val in all_variables.items():
        try:
            print(key, val)
            data = reader.get_tensor(key)
            print(data)
            # print_tensors_in_checkpoint_file(ckpt_path, tensor_name=key,
            #                                  all_tensors=False, all_tensor_names=False)
        except Exception:
            pass
    return all_variables, ckpt_path


def save_parameter_txt(model_dir="save/"):
    ckpt = tf.train.get_checkpoint_state(model_dir)
    ckpt_path = ckpt.model_checkpoint_path
    reader = tf.train.NewCheckpointReader(ckpt_path)
    all_variables = reader.get_variable_to_shape_map()
    print(all_variables)
    pf = open('parameter.txt', 'w')
    # loop over all variables and write each one to the txt file
    for key in all_variables.keys():
        parameter_data = reader.get_tensor(key)
        print('**************** save', key, 'succeed ******************* shape:', parameter_data.shape)
        data_shape = parameter_data.shape
        pf.write(str(key))
        pf.write(',data shape:')
        pf.write(str(all_variables[key]))
        pf.write('\n')
        if len(data_shape) == 0:
            # scalar value
            pf.write(str(parameter_data))
        # save 1-D data format
        if len(data_shape) == 1:
            pf.write('{')
            for i in range(data_shape[0]):
                pf.write(str(parameter_data[i]))
                if i < data_shape[0] - 1:
                    pf.write(',')
            pf.write('}')
        # save 2-D data format
        if len(data_shape) == 2:
            pf.write('{')
            for i in range(data_shape[0]):
                for j in range(data_shape[1]):
                    pf.write(str(parameter_data[i][j]))
                    if not (i == data_shape[0] - 1 and j == data_shape[1] - 1):
                        pf.write(',')
            pf.write('}')
        # save 4-D data format
        if len(data_shape) == 4:
            pf.write('{')
            last = (data_shape[0] - 1, data_shape[1] - 1, data_shape[2] - 1, data_shape[3] - 1)
            for i in range(data_shape[0]):
                for j in range(data_shape[1]):
                    for k in range(data_shape[2]):
                        for l in range(data_shape[3]):
                            pf.write(str(parameter_data[i][j][k][l]))
                            if (i, j, k, l) != last:
                                pf.write(',')
            pf.write('}')
        pf.write('\n')
    pf.close()


if __name__ == '__main__':
    save_parameter_txt()
```
2. Saving the Pruned Sparse Matrix
To store the pruned sparse matrix, first multiply the weights element-wise by the pruning mask, then store the remaining nonzero entries as {row} {col} {data} triples.
Taking a fully connected layer as the example, everything here is a 2-D operation.
An example:
```python
import numpy as np


def weights_csc_matrix(weights_matrix, mask_matrix):
    # apply the pruning mask, then collect the surviving entries
    # as (row, col, value) triples
    weights_matrix = np.multiply(weights_matrix, mask_matrix)
    row_data = []
    col_data = []
    weights_data = []
    print(weights_matrix)
    for i in range(weights_matrix.shape[0]):
        for j in range(weights_matrix.shape[1]):
            if weights_matrix[i][j] != 0:  # != 0, so negative weights are kept too
                row_data.append(i)
                col_data.append(j)
                weights_data.append(weights_matrix[i][j])
    print(row_data, col_data, weights_data)
    return row_data, col_data, weights_data
```
The resulting output can then be written to TXT in the same way as above.
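As a sanity check on the triple format, the stored (row, col, value) entries can be scattered back into the dense pruned matrix (a sketch; rebuild_dense is a hypothetical helper, not part of the original code):

```python
import numpy as np

def rebuild_dense(row_data, col_data, weights_data, shape):
    # inverse of the triple extraction: scatter each stored
    # (row, col, value) entry back into a dense matrix of zeros
    dense = np.zeros(shape)
    for r, c, w in zip(row_data, col_data, weights_data):
        dense[r, c] = w
    return dense

# a small pruned matrix: only two entries survived the mask
pruned = np.array([[0.0, 1.5], [-2.0, 0.0]])
rows, cols, vals = zip(*[(i, j, pruned[i, j])
                         for i in range(pruned.shape[0])
                         for j in range(pruned.shape[1])
                         if pruned[i, j] != 0])
restored = rebuild_dense(rows, cols, vals, pruned.shape)
print(np.array_equal(restored, pruned))  # True
```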