TensorFlow 模型參數保存方法

TensorFlow 模型參數保存方法

TensorFlow 模型參數的保存分為兩種情況:第一種是未剪枝的參數,第二種是剪枝後的參數。未剪枝的參數可以直接保存為 TXT 文件;剪枝後的參數則可以按照稀疏矩陣的存儲方式保存。

 

一、未剪枝參數保存

讀取參數之前,需要先保存模型及其對應的參數,保存方法如下:

def save_model(self):
    """Save all variables of ``self.sess`` to the checkpoint prefix ``save/model.ckpt``.

    ``self.sess`` is presumably the ``tf.Session`` holding the trained graph
    (TODO confirm against the enclosing class).

    Returns:
        The checkpoint path returned by ``Saver.save``.
    """
    import os

    # TF 1.x Saver.save raises ValueError when the parent directory of the
    # checkpoint prefix does not exist, so make sure save/ is there first.
    os.makedirs("save", exist_ok=True)
    saver = tf.train.Saver()
    save_path = saver.save(self.sess, "save/model.ckpt")
    return save_path

調用此方法後,會在 save 目錄下生成對應的模型 checkpoint 文件,之後即可從 ckpt 文件中讀取參數。

保存參數時,只需讀取模型的參數名稱和參數數據,代碼示例如下:

# The module name is lowercase: `import TensorFlow` raises ModuleNotFoundError.
import tensorflow as tf

model_dir = "save/"

# Checkpoint state for model_dir; None when no checkpoint has been saved there.
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt is None:
    raise FileNotFoundError("no checkpoint found in " + model_dir)
ckpt_path = ckpt.model_checkpoint_path

# Open the checkpoint so its variables can be listed and read.
reader = tf.train.NewCheckpointReader(ckpt_path)

# Variable name -> shape for every variable stored in the checkpoint.
all_variables = reader.get_variable_to_shape_map()

上述代碼各行的作用分別為:

1. 指定模型文件所在的目錄;

2. 讀取該目錄下的 checkpoint 狀態;

3. 取得 ckpt 文件的路徑;

4. 創建 reader,用來讀取 checkpoint 中的參數名稱和參數數據;

5. 取得所有參數名稱及其對應的 shape。

 

取得參數名稱後,可以用 get_tensor 讀取對應的 tensor 數據:

# Read the saved value of the variable named `key`; the code below treats the
# result as an array (it reads `.shape`).
parameter_data = reader.get_tensor(key)

 

保存到 TXT 時需要逐個寫出數據:因為數據量很大,如果直接寫入 parameter_data,輸出會被 NumPy 用省略號(...)截斷。具體操作如下:

# Write every variable as a header line followed by its values in C
# initializer form, e.g. "{0.1,0.2,0.3}". Arrays of any rank are flattened in
# row-major (C) order, so commas separate every element and the closing brace
# appears exactly once.
for key in all_variables:
    parameter_data = reader.get_tensor(key)
    data_shape = parameter_data.shape
    pf.write(str(key))
    pf.write(',data shape:')
    pf.write(str(all_variables[key]))
    pf.write('\n')
    if parameter_data.ndim == 0:
        # Scalars are written bare, without braces.
        pf.write(str(parameter_data))
    else:
        # Write element by element: str() of a large array is summarised by
        # NumPy with "..." and would lose data.
        pf.write('{' + ','.join(map(str, parameter_data.ravel())) + '}')
    # Every entry ends with a newline, whatever its rank.
    pf.write('\n')
    print('**************** save', key ,' succeed******************* shape:', data_shape)

pf.close()

具體思路是:為了方便移植到 C 語言中以 define 的形式使用,按照 define 的定義格式,開頭和結尾分別用 { 和 },中間逐個寫入數據,數據之間用逗號分隔。當然,也可以修改每個參數前寫入的頭部信息:

  pf.write(str(key))

  pf.write(',data shape:')

  pf.write(str(all_variables[key]))

  pf.write('\n')

這幾行可以修改為直接輸出 #define PARAMETER_STR DATA 這種格式,以便直接在 C 代碼中使用。

 

完整代碼如下:

import tensorflow as tf

import numpy as np

from tensorflow.python import pywrap_tensorflow

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

 

def view_parameter(model_dir):
    """Print every variable stored in the latest checkpoint of ``model_dir``.

    Prints the full name -> shape map, then for each variable its name,
    shape and value. A variable whose value cannot be read is reported and
    skipped instead of aborting the listing.

    Args:
        model_dir: directory holding the checkpoint written by ``tf.train.Saver``.

    Returns:
        ``(all_variables, ckpt_path)``: the name -> shape map from
        ``get_variable_to_shape_map`` and the path of the checkpoint read.

    Raises:
        FileNotFoundError: if ``model_dir`` contains no checkpoint.
    """
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None:
        raise FileNotFoundError("no checkpoint found in %s" % model_dir)
    ckpt_path = ckpt.model_checkpoint_path

    # A single reader is enough to list names/shapes and to fetch values.
    reader = tf.train.NewCheckpointReader(ckpt_path)
    all_variables = reader.get_variable_to_shape_map()
    print(all_variables)

    for key, shape in all_variables.items():
        print(key, shape)
        try:
            data = reader.get_tensor(key)
        except Exception as err:  # best effort: report the failure, keep listing
            print('failed to read', key, ':', err)
            continue
        print(data)

    return all_variables, ckpt_path

 

def save_parameter_txt(model_dir="save/"):
    """Dump every variable of the latest checkpoint in ``model_dir`` to ``parameter.txt``.

    Each variable is written as a header line followed by a data line::

        <name>,data shape:<shape>
        {v0,v1,...,vN}

    Values of any rank are flattened in row-major (C) order and
    comma-separated inside braces, so the data line can be pasted into a C
    array initializer / ``#define``. Scalars are written without braces.

    Args:
        model_dir: directory holding the checkpoint written by ``tf.train.Saver``.

    Raises:
        FileNotFoundError: if ``model_dir`` contains no checkpoint.
    """
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None:
        raise FileNotFoundError("no checkpoint found in %s" % model_dir)
    ckpt_path = ckpt.model_checkpoint_path

    reader = tf.train.NewCheckpointReader(ckpt_path)
    all_variables = reader.get_variable_to_shape_map()
    print(all_variables)

    # `with` closes the file even if reading a tensor fails midway.
    with open('parameter.txt', 'w') as pf:
        for key, shape in all_variables.items():
            parameter_data = reader.get_tensor(key)
            pf.write(str(key))
            pf.write(',data shape:')
            pf.write(str(shape))
            pf.write('\n')
            if parameter_data.ndim == 0:
                pf.write(str(parameter_data))
            else:
                # Write every element explicitly: str() of a large array is
                # summarised by NumPy with "..." and would lose data.
                pf.write('{' + ','.join(map(str, parameter_data.ravel())) + '}')
            pf.write('\n')
            print('**************** save', key ,' succeed******************* shape:', parameter_data.shape)

 

if __name__ == '__main__':
    # Script entry: dump the latest checkpoint under save/ into parameter.txt.
    save_parameter_txt(model_dir="save/")

 

二、保存剪枝後的稀疏矩陣

存儲剪枝後的稀疏矩陣時,需要先取出權重參數與剪枝掩碼(mask)逐元素相乘,再將非零數據按 {row}{col}{data} 三部分存儲。

以全連接層為例,權重均為二維矩陣。

示例如下:

def weights_csc_matrix(weights_matrix, mask_matrix, *, verbose=True):
    """Extract the non-zero entries of a pruned 2-D weight matrix.

    The weights are multiplied element-wise by the pruning mask, then every
    non-zero entry is recorded as a (row, column, value) triplet. Despite the
    name, the result is a coordinate (COO) triplet list ordered row by row,
    left to right -- not compressed sparse column storage.

    Args:
        weights_matrix: 2-D weight matrix (e.g. of a fully connected layer).
        mask_matrix: pruning mask multiplied element-wise with the weights
            (presumably 0 for pruned and 1 for kept entries -- confirm).
        verbose: when true, print the masked matrix and the triplets.

    Returns:
        ``(row_data, col_data, weights_data)``: three equally long lists with
        the row index, column index and value of each non-zero entry.

    Raises:
        ValueError: if the masked weights are not 2-D.
    """
    weights_matrix = np.multiply(weights_matrix, mask_matrix)
    if weights_matrix.ndim != 2:
        raise ValueError(
            "expected a 2-D weight matrix, got shape %s" % (weights_matrix.shape,))
    if verbose:
        print(weights_matrix)

    # Keep every non-zero weight, negative ones included: pruning zeroes
    # weights out, so a "> 0" test would silently drop all negative weights.
    # np.nonzero returns indices in row-major order.
    rows, cols = np.nonzero(weights_matrix)
    row_data = rows.tolist()
    col_data = cols.tolist()
    weights_data = list(weights_matrix[rows, cols])

    if verbose:
        print(row_data, col_data, weights_data)
    return row_data, col_data, weights_data

得到的結果再寫入 TXT 保存即可,原理與上文類似。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章