【Tensorflow】A utility class for manipulating parameter names and matrix values in checkpoints

0x00 Preface

Compared with PyTorch parameter files, TensorFlow model parameter files are inconvenient to work with directly,
and a current task requires that, in a few parameter matrices, the values of specific rows be copied into certain other rows.
In PyTorch this is easy, since a checkpoint is after all just a group of tensors wrapped in an OrderedDict, a basic Python data structure.
Doing the same thing in TensorFlow is rather cumbersome, so I implemented the utility class CheckpointMonitor to make such processing more efficient.
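
For contrast, the same row-copy is nearly trivial in PyTorch (a minimal sketch; the file name and parameter key are hypothetical):

import torch

# a PyTorch checkpoint is just an OrderedDict of tensors
state_dict = torch.load('model.pt', map_location='cpu')
weight = state_dict['embedding.weight']
weight[10] = weight[3]        # copy row 3 into row 10, in place
torch.save(state_dict, 'model_modified.pt')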

0x01 Features and API

  • Supports modifying any parameter matrix in a TensorFlow checkpoint (ckpt) file
    • Parameter names can be changed individually or in batch, while every other attribute of each parameter stays intact
      • Batch renaming works by passing in a function: each input parameter name is mapped to an output name by the user-defined function (see the sketch after this list)
      • This step is needed, for example, when converting parameters between TensorFlow and PyTorch
    • The modified parameters can be saved back as a TensorFlow checkpoint (Figure 1 below) or as a PyTorch file (Figure 2 below)
    • Any parameter matrix can be filtered, inspected, and modified, wholly or partially; the tool class exposes everything as numpy arrays throughout
    • The parameter order in the model file is maintained automatically, and existing parameters can be extended, e.g. by concatenation
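
As a sketch of the batch-renaming hook mentioned above, the function below maps TensorFlow-style variable names to PyTorch-style ones (the scope prefix and name patterns are hypothetical; a real conversion depends on the model in question):

def tf_to_torch_name(var_name):
    # hypothetical mapping: drop a scope prefix, switch separators,
    # and rename 'kernel' to 'weight' as PyTorch layers expect
    name = var_name.replace('model/', '')
    name = name.replace('/', '.')
    return name.replace('kernel', 'weight')

monitor = CheckpointMonitor('./model.ckpt-250042')  # hypothetical path
monitor.modify_var_names(rename_func=tf_to_torch_name)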

0x02 API list

  • __init__(checkpoint_path) is constructed with the checkpoint path as its argument
  • list_variables() lists every parameter in the current checkpoint together with its shape
  • list_target_variables(pattern) is like list_variables, but returns only the parameters whose names contain the filter pattern (Figure 3)
  • get_var_data(var_name) returns the named parameter from the model file as a numpy array
  • save_model(path, method='pt') saves the model as a PyTorch file, or back as a TensorFlow checkpoint with method='tf'
  • modify_var_name(old_name, new_name) renames a single parameter
  • modify_var_names(rename_func) renames parameters in batch
  • modify_var_data(var_name, var_data) replaces a parameter's values
  • That is all for now; more may be added as needs arise (e.g. encryption/decryption or model-compression utilities could be folded into this class)
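
Putting the API together, the row-copy task from the preface looks like the following minimal sketch (the checkpoint path, variable name, and row indices are hypothetical):

monitor = CheckpointMonitor('./model.ckpt-250042')
monitor.list_target_variables('embedding')   # locate the target matrix

data = monitor.get_var_data('bert/embeddings/word_embeddings')
data[100] = data[0]                          # copy row 0 into row 100
monitor.modify_var_data('bert/embeddings/word_embeddings', data)

monitor.save_model('./output', 'patched_model', method='tf')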

0x03 requirements

  • python >= 3.6 (lower versions untested)
  • tensorflow >= 1.15 (lower versions untested)
  • torch >= 1.4 (only needed for saving to torch format)
  • numpy

0x04 Source Code

import os
# hide GPUs so everything runs in CPU mode
os.environ['CUDA_LAUNCH_BLOCKING'] = ""
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
from collections import OrderedDict


class CheckpointMonitor(object):
    """
    # CPU mode
    import os
    os.environ['CUDA_LAUNCH_BLOCKING'] = ""
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    """
    def __init__(self, checkpoint_path=None):
        if checkpoint_path is None:  # default path for testing
            checkpoint_path = '/data/sharedata/model_files/model.ckpt-250042'
        
        self.saver = None
        self.graph = None
        self.dump_path = './'
        self.checkpoint_path = checkpoint_path
        self.default_dump_name = 'my_modified_model'
        self.var_name_list = []
        self.var_shape_dict = OrderedDict()
        self.var_data_dict = OrderedDict()
        self.init_vars()
    
    def reload(self, checkpoint_path=None):
        if checkpoint_path is None:
            # default to re-reading the checkpoint this monitor was built from
            checkpoint_path = self.checkpoint_path
        self.__init__(checkpoint_path=checkpoint_path)
    
    def init_vars(self, checkpoint_path=None):
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        self.var_shape_dict = OrderedDict(
            self.list_variables(checkpoint_path))
        self.var_name_list = list(self.var_shape_dict.keys())
        for var_name in self.var_name_list:
            # print(var_name)
            var_data = self.get_var_data(var_name, checkpoint_path)
            # dict(str, np.array)
            self.var_data_dict.update({var_name: var_data})
    
    def sort_var_dicts(self):
        self.var_data_dict = OrderedDict(
            [(var_name, self.var_data_dict[var_name]) 
             for var_name in self.var_name_list])
        self.var_shape_dict = OrderedDict(
            [(var_name, self.var_shape_dict[var_name]) 
             for var_name in self.var_name_list])
    
    def list_variables(self, checkpoint_path=None):
        # get all variables in form of tuple(name, shape) in checkpoint
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        # return a list of (var_name, shape)
        return tf.contrib.framework.list_variables(checkpoint_path)
    
    def list_target_variables(self, pattern, checkpoint_path=None):
        if checkpoint_path is None:
            if len(self.var_shape_dict) != 0:
                # serve from the in-memory cache
                var_list = self.var_shape_dict.items()
                return [(name, shape) for (name, shape)
                        in var_list if pattern in name]
            else:  # cold start: fall back to reading the checkpoint
                checkpoint_path = self.checkpoint_path
        var_list = self.list_variables(checkpoint_path)
        return [(name, shape) for (name, shape) in var_list if pattern in name]
    
    def get_var_data(self, var_name, checkpoint_path=None):
        # load the variable named var_name from the target checkpoint
        if checkpoint_path is None:
            if len(self.var_data_dict) != 0:
                # serve from the in-memory cache
                return self.var_data_dict.get(var_name)
            checkpoint_path = self.checkpoint_path
        # return the variable values as an np.ndarray
        return tf.contrib.framework.load_variable(checkpoint_path, var_name)
    
    @staticmethod
    def generate_rename_func(old_name_list, new_name_list):
        def fn(var_name):
            if var_name in old_name_list:
                return new_name_list[old_name_list.index(var_name)]
            return var_name
        return fn
    
    def modify_var_name(self, old_name, new_name, inplace=True):
        # swap the name in the ordered name list, then move the data
        # and shape entries over to the new key
        var_index = self.var_name_list.index(old_name)
        self.var_name_list[var_index] = new_name
        self.var_data_dict[new_name] = self.var_data_dict.pop(old_name)
        self.var_shape_dict[new_name] = self.var_shape_dict.pop(old_name)
        if inplace:
            self.sort_var_dicts()
    
    def modify_var_names(self, rename_func=None):
        # modify var_names in batch, with a user-supplied `rename_func`
        if rename_func is None:
            rename_func = lambda _name: _name

        # iterate over a copy, since modify_var_name mutates the name list;
        # no tf.Session is needed here, as only dict keys are touched
        for var_name in list(self.var_name_list):
            new_name = rename_func(var_name)
            if new_name != var_name:
                self.modify_var_name(var_name, new_name, inplace=False)
                print('Re-naming {} to {}.'.format(var_name, new_name))
        self.sort_var_dicts()
    
    def modify_var_data(self, var_name, var_data):
        assert isinstance(var_data, np.ndarray)
        if var_name not in self.var_name_list:
            print("Invalid variable name: {}".format(var_name))
            print("You can get available variable names by calling list_variables()")
            return
        self.var_shape_dict[var_name] = list(var_data.shape)
        self.var_data_dict[var_name] = var_data
    
    def generate_var_dict_for_torch(self, var_list=None):
        import torch  # deferred import: torch is only needed for saving as pt
        if var_list is None:
            var_list = self.var_data_dict.items()
        torch_model_dict = OrderedDict()
        for var_name, var_data in var_list:
            var = torch.tensor(var_data)
            torch_model_dict.update({var_name: var})
        return torch_model_dict
    
    def generate_var_list_for_saver(self, var_list=None):
        if var_list is None:
            var_list = self.var_data_dict.items()
        saver_var_list = []
        # tf.Variable objects are created in the default graph;
        # no session is needed at construction time
        for var_name, var_data in var_list:
            var = tf.Variable(var_data, name=var_name)
            saver_var_list.append(var)
        return saver_var_list
    
    def save_model(self, new_checkpoint_path=None, model_name=None, method='pt'):
        if new_checkpoint_path is None:
            new_checkpoint_path = self.dump_path
        if not os.path.exists(new_checkpoint_path):
            os.makedirs(new_checkpoint_path)
        if model_name is None:
            model_name = self.default_dump_name
        checkpoint_path = os.path.join(
            new_checkpoint_path, model_name)
        
        method_dict = {
            'pt': self.save_model_as_pt,
            'tf': self.save_model_as_tf,
            'ckpt': self.save_model_as_tf,
            'torch': self.save_model_as_pt,
            'pytorch': self.save_model_as_pt,
            'tensorflow': self.save_model_as_tf,
        }
        method_dict[method](checkpoint_path)
    
    def save_model_as_pt(self, checkpoint_path):
        import torch
        var_dict = self.generate_var_dict_for_torch()
        checkpoint = OrderedDict({'model': var_dict})
        torch.save(checkpoint, checkpoint_path + '.pt')
        print("Checkpoint saving finished !\n{}".format(
            checkpoint_path + '.pt'))
    
    def save_model_as_tf(self, checkpoint_path):
        # start from a fresh graph so repeated saves do not pile up
        # duplicate (suffixed) variables in the default graph
        tf.reset_default_graph()
        with tf.Session() as sess:
            var_list = self.generate_var_list_for_saver()
            # Construct the Saver over exactly the variables we rebuilt
            self.saver = tf.train.Saver(var_list=var_list)
            # Necessary! Initialize the variables before saving.
            sess.run(tf.global_variables_initializer())
            self.saver.save(sess, checkpoint_path)
            print("Checkpoint saving finished!\n{}".format(
                checkpoint_path))

0x05 Demonstration

Figure 1: load the original TF model → modify a single value → save back → load the new TF model → verify the change
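
Since the screenshot itself is not reproduced here, the Figure 1 workflow amounts to the following sketch (the checkpoint paths and the variable name are hypothetical):

monitor = CheckpointMonitor('./model.ckpt-250042')
data = monitor.get_var_data('dense/bias')
data[0] = 3.14                                # modify a single value
monitor.modify_var_data('dense/bias', data)
monitor.save_model('./output', 'modified_model', method='tf')

# load the new checkpoint and verify the modification
check = CheckpointMonitor('./output/modified_model')
print(check.get_var_data('dense/bias')[0])    # expect 3.14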

Figure 2: load the original TF model → modify a single value → save as a PyTorch model → load the new PyTorch model → verify the change
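
The Figure 2 workflow differs only in the save and verification steps (again a sketch; the paths are hypothetical, and note that save_model_as_pt nests the parameters under the 'model' key):

monitor.save_model('./output', 'modified_model', method='pt')

import torch
checkpoint = torch.load('./output/modified_model.pt')
print(checkpoint['model']['dense/bias'][0])   # expect tensor(3.1400)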
