【TensorFlow】Implementing a Partial Connect Layer by Hand

0x00 Preface

Generally speaking, machine learning work in NLP constantly runs into a requirement of this form: computing argmax_i P(w_i | θ).
The most common approach is to build a full connect layer with one output dimension per vocabulary word, followed by a softmax. However, when the vocabulary |V| is large, this wastes a great deal of computation (in my current work the vocabulary holds 80k words).
When we already know that only certain words need to be scored, the redundant computation can be skipped outright (even if we add a mask that sets the unneeded part to 0, it is still better not to perform those multiplications by 0 at all), so the idea here is to build a new partial connect layer to solve this problem.
(I was too lazy to upgrade TF; if your TF environment is freshly set up, do try searching the official docs for the various methods whose names start with sparse_.)
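
To make the cost concrete, here is a minimal sketch (my own illustration; the 512-dim context vector and the placeholder are made-up values, only the 80k vocabulary size comes from this post) of what the common full-vocabulary approach computes:

import tensorflow as tf

# Score every word in an 80k vocabulary, even when only a few candidates matter.
hidden = tf.placeholder(tf.float32, shape=[None, 512])    # some context vector; 512 is made up
W = tf.Variable(tf.random_normal([512, 80000]))           # one column per vocabulary word
b = tf.Variable(tf.zeros([80000]))
logits = tf.matmul(hidden, W) + b                          # cost grows linearly with |V|
pred = tf.argmax(tf.nn.softmax(logits), axis=-1)           # argmax_i P(w_i | θ)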

0x01 Design Approach

Personally I am fairly lazy and was never going to write this from the very bottom layers up;
so the idea, roughly, is just to stitch together an embedding_layer and a full_connect_layer.
The basic EmbLayer and FulLayer are shown below (thanks to @lhw446):

import math
import pickle

import numpy as np
import tensorflow as tf


class EmbLayer(object):
    """embedding layer"""
    def __init__(self, word_emb_shape, word_embeddings_path=None, voc_path=None,
                 partition_strategy='mod', validate_indices=True, max_norm=None,
                 weight_decay=None, stop_gradient=False, show_no_word=True, name='emb'):
        # params
        self.partition_strategy = partition_strategy
        self.validate_indices = validate_indices
        self.word_embeddings = None
        self.max_norm = max_norm
        self.name = name

        with tf.name_scope('{}_def'.format(self.name)):
            scale = math.sqrt(2.0 / np.prod(word_emb_shape[-1]))
            self.word_embeddings = scale * np.random.standard_normal(size=word_emb_shape)

            if word_embeddings_path:
                assert voc_path is not None
                idx2word = pickle.load(open(voc_path, 'rb'))['idx2word']
                word2vec = pickle.load(open(word_embeddings_path, 'rb'))
                for idx in range(word_emb_shape[0]):
                    word = idx2word[idx]
                    if word in word2vec:
                        self.word_embeddings[idx] = scale * 0.1 * word2vec[word][:word_emb_shape[1]]
                    elif show_no_word:
                        print('word2vec no word {}: {}'.format(idx, word.encode('utf-8')))

            self.word_embeddings = tf.Variable(
                initial_value=self.word_embeddings, dtype=tf.float32, name='word_embeddings')

            if stop_gradient:
                self.word_embeddings = tf.stop_gradient(self.word_embeddings)

            if weight_decay:
                tf.add_to_collection(
                    'losses', tf.multiply(tf.nn.l2_loss(self.word_embeddings),
                                          weight_decay, name='weight_decay_loss'))

    def __call__(self, ids):
        with tf.name_scope('{}_cal'.format(self.name)):
            outputs = tf.nn.embedding_lookup(self.word_embeddings, ids,
                                             partition_strategy=self.partition_strategy,
                                             validate_indices=self.validate_indices,
                                             max_norm=self.max_norm,
                                             name=self.name)
            return outputs


class FulLayer(object):
    """Full Connect Layer"""
    def __init__(self, input_dim, output_dim, activation_fn=tf.sigmoid,
                 weight_decay=None, name="ful"):
        weight_shape = (input_dim, output_dim)
        self.activation_fn = activation_fn
        self.name = name

        with tf.name_scope('{}_def'.format(self.name)):
            # weight matrix
            scale = math.sqrt(2.0 / np.prod(weight_shape[:-1]))
            init_value = scale * np.random.standard_normal(size=weight_shape)
            self.weight = tf.Variable(init_value, dtype=tf.float32, name='weight')

            if weight_decay:
                tf.add_to_collection(
                    'losses', tf.multiply(tf.nn.l2_loss(self.weight),
                                          weight_decay, name='weight_decay_loss_w'))

            # bias vector
            self.bias = tf.Variable(
                initial_value=tf.constant(0.0, shape=[output_dim]),
                dtype=tf.float32, name='bias')
            if weight_decay:
                tf.add_to_collection(
                    'losses', tf.multiply(tf.nn.l2_loss(self.bias),
                                          weight_decay, name='weight_decay_loss_b'))

    def __call__(self, inputs):
        with tf.name_scope('{}_cal'.format(self.name)):
            shape = tf.concat([tf.shape(inputs)[:-1], [tf.shape(self.weight)[-1]]], 0)
            # shape = tf.concat([tf.shape(inputs)[:-1], [self.weight.shape[-1]]], 0)

            inputs = tf.reshape(inputs, [-1, tf.shape(inputs)[-1]])
            outputs = tf.add(tf.matmul(inputs, self.weight), self.bias)
            if self.activation_fn is not None:
                outputs = self.activation_fn(outputs)

            outputs = tf.reshape(outputs, shape)
            return outputs

As you can see, the core idea of the Embedding Layer is to maintain a parameter space of N rows, each row a C-dimensional vector.
Given an index matrix (one- or multi-dimensional), such as [0, 54, 900, 233] or [[0, 2], [3, 6]],
it returns a matrix with one extra dimension, e.g. (10) -> (10, C) or (2, 3) -> (2, 3, C),
in which every index i has been replaced by the i-th C-dimensional row of that parameter space.
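
A toy demonstration of that shape behaviour (sizes and values are made up; TF1-style session as in the rest of the post):

import numpy as np
import tensorflow as tf

# N = 5 rows, each a C = 3 dimensional vector; the indices are arbitrary examples.
params = tf.constant(np.arange(15).reshape(5, 3), dtype=tf.float32)
ids = tf.constant([[0, 2], [3, 1]])                 # shape (2, 2)
looked_up = tf.nn.embedding_lookup(params, ids)     # shape (2, 2, 3): each id replaced by its row

with tf.Session() as sess:
    print(sess.run(looked_up))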

The Full Connect Layer, in turn, maintains a weight matrix and a bias vector, i.e. the W and b in the familiar y = Wx + b.
For an input x it performs that computation and returns the result; both W and b are updated during backpropagation.
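
Putting the two classes together, a usage sketch (the 80k vocabulary comes from this post; the 128-dim embedding size and the placeholder are my own example values) of the full |V|-sized output layer that the partial layer below is meant to replace:

# Illustrative wiring only; shapes other than the 80k vocabulary are invented.
emb_layer = EmbLayer(word_emb_shape=(80000, 128), name='emb')   # 80k words, 128-dim embeddings
ful_layer = FulLayer(input_dim=128, output_dim=80000,
                     activation_fn=None, name='ful')            # one logit per vocabulary word

ids = tf.placeholder(tf.int32, shape=[None, None])              # [batch, seg_len]
word_emb = emb_layer(ids)                                        # [batch, seg_len, 128]
logits = ful_layer(word_emb)                                     # [batch, seg_len, 80000]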

So the plan is obvious: on top of the Full Connect Layer, our Partial Connect Layer only needs to keep W and b as two parameter spaces that can be indexed with embedding_lookup. Before doing the multiplication and addition, we obtain the row indices that actually need to be computed, use the Embedding Layer trick to gather the sub-matrix W' and sub-vector b' that are actually required, and output y = W'x + b'.
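
A quick numpy sanity check of this idea (the names and sizes are mine): gathering the candidate rows of W and b first and then multiplying yields exactly the candidate entries of the full product.

import numpy as np

input_dim, vocab_size = 4, 10
W = np.random.randn(vocab_size, input_dim)     # word-indexed weight rows (the "transposed" W)
b = np.random.randn(vocab_size)
x = np.random.randn(input_dim)
candidates = [1, 7, 3]                         # the only words we need to score

full_logits = W.dot(x) + b                                 # cost proportional to |V|
partial_logits = W[candidates].dot(x) + b[candidates]      # cost proportional to len(candidates)
assert np.allclose(full_logits[candidates], partial_logits)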

0x02 Source Code

class PartialLayer(object):
    """Partial Connect Layer"""
    def __init__(self, input_dim, output_dim, partial_dim,
                 activation_fn=tf.sigmoid, weight_decay=None, name="par"):
        self.partial_dim = partial_dim
        self.activation_fn = activation_fn
        self.weight_shape = (input_dim, output_dim)
        self.name = name

        with tf.name_scope('{}_def'.format(self.name)):
            # weight matrix
            scale = math.sqrt(2.0 / np.prod(self.weight_shape[:-1]))
            init_value = scale * np.random.standard_normal(size=self.weight_shape)
            self.weight = tf.Variable(init_value, dtype=tf.float32, name='weight')

            # bias vector
            self.bias = tf.Variable(
                initial_value=tf.constant(0.0, shape=[output_dim]),
                dtype=tf.float32, name='bias')

            if weight_decay:
                tf.add_to_collection(
                    'losses', tf.multiply(tf.nn.l2_loss(self.weight),
                                          weight_decay, name='weight_decay_loss_w'))
                tf.add_to_collection(
                    'losses', tf.multiply(tf.nn.l2_loss(self.bias),
                                          weight_decay, name='weight_decay_loss_b'))

            # embedding_lookup indexes by row, so keep word-indexed copies:
            # weight as (output_dim, input_dim) and bias as (output_dim, 1)
            self.transposed_weight = tf.transpose(self.weight)
            self.transposed_bias = tf.expand_dims(self.bias, -1)

    def get_partial_weight(self, targets):
        return tf.nn.embedding_lookup(
            self.transposed_weight, targets,
            partition_strategy='mod',
            validate_indices=True,
            max_norm=None,
            name='partial_weight'
        )

    def get_partial_bias(self, targets):
        return tf.nn.embedding_lookup(
            self.transposed_bias, targets,
            partition_strategy='mod',
            validate_indices=True,
            max_norm=None,
            name='partial_bias'
        )

    def __call__(self, inputs, targets):
        """
        global weight is lstm_dim*2, n_words
        :param inputs: batch, seg_len, lstm_dim*2
        :param targets: batch, seg_len, can_len
        :return: batch, seg_len, can_len
        """
        with tf.name_scope('{}_cal'.format(self.name)):
            inputs = tf.expand_dims(inputs, -1)
            # print(inputs.shape)

            partial_weight = self.get_partial_weight(targets)
            partial_bias = self.get_partial_bias(targets)

            # [batch, seg_len, can_len, input_dim] x [batch, seg_len, input_dim, 1] -> [batch, seg_len, can_len, 1]
            # print (inputs.shape, partial_weight.shape, partial_bias.shape)
            # print (type(inputs), type(partial_weight), type(partial_bias))

            outputs = tf.add(tf.matmul(partial_weight, inputs), partial_bias)
            if self.activation_fn is not None:
                outputs = self.activation_fn(outputs)

            # batch, seg_len, can_len
            outputs = tf.reshape(outputs, tf.shape(outputs)[:-1])
            return outputs

# layer initialization in the network's __init__()
from path.to.my.utils import options
self.partial_layer = PartialLayer(input_dim=2 * options.get('lstm_dim'),
                                  output_dim=options.get('n_words'),
                                  partial_dim=options.get('max_can_len'),
                                  activation_fn=None,
                                  weight_decay=options.get('weight_decay'),
                                  name='pc_layer')
# example network construction.
def get_network(self):
    # [batch, seg_len + 2] -> [batch, seg_len + 2, emb_dim]
    word_emb = self.emb_layer(self.input_data)

    # [batch, seg_len(+2), emb_dim] -> [batch, seg_len, lstm_dim*2]
    forward_hidden, backward_hidden = self.lstm_layer(word_emb)
    context_hidden = tf.concat([forward_hidden, backward_hidden], axis=-1)

    # [batch, seg_len, lstm_dim*2] -> [batch, seg_len, can_len]
    partial_hidden = self.partial_layer(context_hidden, self.candidates)
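
As a possible continuation (my own sketch, not from the original post), the candidate scores can be normalized and argmax-ed per position; note that the resulting index points into each position's candidate list, not into the full vocabulary:

    # Illustrative continuation of get_network(); not part of the original post.
    # Softmax over the candidate axis gives P(candidate | context) at each position.
    candidate_probs = tf.nn.softmax(partial_hidden)        # [batch, seg_len, can_len]
    # Index of the best-scoring candidate within each position's candidate list.
    best_candidate = tf.argmax(partial_hidden, axis=-1)    # [batch, seg_len]
    return candidate_probs, best_candidate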

0x03 Postscript

A quick assessment:
This approach uses index lookups to fetch only the regions that actually need to be computed, which cuts the amount of computation drastically. The gather step during the forward pass does add a small amount of extra GPU memory, roughly the ratio of the number of candidates to the number of rows in the parameter space (e.g. with an 80k vocabulary and 80 candidates, about 0.1% extra memory).
Beyond that, I am quite interested in the sparse_ family of methods and plan to read up on their behaviour and usage later.
