神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览

神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览


与TF1.x的实现思路不同,在TF2.0中,RNN已经不再是个函数,而是一个封装好的类,各种RNN和RNNCell与顶层抽象类Layer的关系也更加紧凑(需要说明的是说Layer顶层并非说它直接继承自object,而是从……功能的角度,我觉得可以这么说。真实实现里的继承关系是Layer --> Module --> AutoTrackable --> Trackable --> object)。但是另一方面,感觉新的版本里各个类的关系稍微有些杂乱,不知道后面会不会进一步重构。TF2.0的RNN相关各类关系大致如下图所示

在这里插入图片描述

相关基类

tf.keras.layers.Layer

与TF1.14的实现基本相同,不再赘述

recurrent.DropoutRNNCellMixin

与之类似的类在TF1.x中以tf.nn.rnn_cell.DropoutWrapper形式出现,但当时考虑到还没涉及到RNN的dropout就没有引入,没想到在这里还是要说一说。TF2的实现比TF1的实现要简单一些,这个类只是维护两个dropout mask,一个是用于对输入的mask,一个用于对传递状态的mask(严格说是四个,在另一个维度上还考虑是对静态图的mask还是对eager模式的mask)。实现保证mask只被创建一次,因此每个batch使用的mask都相同

RNNCell相关

无论是官方给出的文本分类教程,还是我自己从TF1.x改的用更底层API实现的代码,实际上都没有用到Cell相关的对象。但是为了完整起见(毕竟暴露的LSTM类背后还需要LSTMCell类对象作为自己的成员变量),这里还是稍作介绍

LSTMCell

本文以LSTM为主,因此先从LSTMCell说起。与TF1.x不同,在2.x版本里,LSTMCell允许传入一个implement参数,默认为1,标记LSTM各门和输出、状态的计算方式。当取默认的1时,计算方式更像是论文中的方式,逐个计算各个门的结果;而如果设为2,则使用TF1.x中组合成矩阵一并计算的方式。此外,由于LSTMCell还继承了前述DropoutRNNCellMixin接口,因此可以在call里对输入和上一时间步传来的状态做dropout。注意由于LSTM有四个内部变量i\boldsymbol{i}f\boldsymbol{f}o\boldsymbol{o}c~\tilde{\boldsymbol{c}},因此需要各自生成四个不同的dropout mask

PeepholeLSTMCell

只是改写了LSTMCell内部变量的计算逻辑,参见在TF1.x部分的介绍

StackedRNNCells

与TF1.x中的MultiRNNCell类似

AbstractRNNCell

纯抽象类,类似TF1的RNNCell,如果用户自己实现一个RNNCell需要 可以继承于它。不过有趣的是内置的三种RNN实现所使用的Cell:SimpleRNNCellGRUCellLSTMCell均直接继承自Layer

RNN相关

tf.keras.layers.RNN

所有后续RNN相关类的基类,承担TF1.x中static_rnndynamic_rnn的双重功能,主要逻辑分别集中在初始化函数__init__buildcall中(__call__也有一些逻辑,但是只针对某些特殊情况)

RNN在初始化时传入的参数个人感觉相对来讲不如1.x直观。其允许传入的参数包括

  • cell:一种RNNCell的对象,也可以是列表或元组。当传入的参数为列表或元组时,会打包组合为StackedRNNCells类对象
  • return_sequences:默认RNN只返回最后一个时间步的输出。当此参数设为True时,返回每个时间步的输出
  • return_state:当此参数设为True时,返回最终状态
  • go_backwards:当此参数设为True时,将输入逆序处理
  • stateful:当此参数设为True时,每个batch第i个样本的最终状态会作为下个batch第i个样本的初始状态
  • unroll:当此参数设为True时,相当于1.x版本中的static rnn,网络被展开。文档认为展开网络可以加速RNN,但显然代价是使用的显存资源会变多
  • time_major:当此参数设为True时,第一个维度为时间维;否则为batch维
  • zero_output_for_mask:没有在接口中直接暴露出来,而是隐藏在**kwargs中。当此参数设为True时,mask对应的时间步输出都为0,否则照搬前一个时间步的输出

build实际只是调用cellbuild方法,并做一些校验

def build(self, input_shape):
    step_input_shape = get_step_input_shape(input_shape)
    if not self.cell.built:
        self.cell.build(step_input_shape)
    self._set_state_spec(state_size)
    if self.stateful:
        self.reset_states()
    self.built = True

call的核心是调用keras后端方法keras.backend.rnnK.rnn

def _process_inputs(inputs, initial_state, ...):
    if initial_state is not None:
        pass
    elif self.stateful:
        initial_state = self.states
    else:
        get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
        if get_initial_state_fn:
            initial_state = get_initial_state_fn()
        else:
            initial_state = zero_state
    return inputs, initial_state, ...

def call(self, inputs, ...):
    inputs, initial_state, ... = self._process_inputs(inputs, initial_state, ...)
    def step(inputs, states):
        output, new_states = self.cell.call(inputs, states)
    last_output, outputs, states = K.rnn(step, inputs, initial_state, ...)
    if self.stateful:
        updates = [assign_op(old, new) for old, new in zip(self.states, states)
        self.add_update(updates)
    if self.return_sequences:
        output = outputs
    else:
        output = last_output
    
    if self.return_state:
        return to_list(output) + states
    return output

K.rnn对RNN是否展开(unroll)和是否需要mask有不同的逻辑,这里只列出不展开且有mask的逻辑。个人感觉和1.x版本中dynamic_rnn的实现方法大同小异

def rnn(step_function, inputs, initial_states, ...):
    # 转换成time major
    inputs = swap_batch_timestep(inputs)
    mask = swap_batch_timestep(mask)
    time_steps_t = inputs[0].shape[0]
   
    input_ta = TensorArray(inputs)
    output_ta = TensorArray(shape=inputs[0].shape)
    mask_ta = TensorArray(mask)
    states = tuple(initial_states)
    prev_output = 0
    time = 0
    while time < time_steps_t:
        current_input = input_ta[time]
        mask_t = mask_ta[time]
        output, new_states = step_function(current_input, states)
        mask_output = 0 if zero_output_for_mask else prev_output
        new_output = where(mask, output, mask_output)
        new_states = where(mask, new_states, states)
        output_ta.append(new_output)
        prev_output, states = new_output, new_states
        time += 1
    return output_ta[-1], output_ta, states

recurrent.LSTM

与父类相比实际上只额外做了两件事

  • 初始化时cell固定为LSTMCell
  • 调用父类的call之前先重置两个dropout mask

recurrent_v2.LSTM

使用tf.keras.layers.LSTM类对象时实际使用的类,之所以带“v2”是因为整合了CuDNN的实现,所以理论上速度会更快,效率会更高。不过使用时需加入如下两行代码

physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

tf.keras.layers.Bidirectional

与RNN类似地,在TF2.0里双向RNN也不再实现为函数,而是实现为一个Layer对象的包装器,为Layer对象提供一定的额外功能。由于Bidirectional也是间接继承自Layer类,因此其大部分逻辑也是蕴含在call方法中

初始化Bidirectional主要需要传入一个Layer类对象layer——不过从实现来看,这个类对象应该还是要是RNN或者其子类的对象。可选的三个字段包括:

  • merge_mode,指定正向和反向RNN的输出如何组合,可以是如下几种选择:求和sum、逐元素相乘mul、直接相连concat、求均值ave或直接返回两个输出组成的一个列表None
  • weights,指定两个RNN的初始化权重
  • backward_layer:允许用户直接传入已经反向的RNN。如果backward_layerNone(默认情况),Bidirectional在初始化时会先根据layer对象的config重构一个RNN,再使用相同的配置构建对应的反向RNN。Bidrectional会强制让自己的两个RNN成员对被mask掉的部分输出为0(zero_output_for_mask强制为True

Bidirectionalbuild实际上就是调用两个RNN成员的build。对应地,call方法也是调用两个RNN成员的call然后根据指定的merge_mode组合输出。源代码看上去略长是因为处理了多个输入和初始状态不为空的情况,而常见的单输入无初始状态下,逻辑相对直观,大致如下:

def call(self, inputs):
    y = self.forward_layer(inputs, **kwargs)
    y_rev = self.backward_layer(inputs, **kwargs)

    if self.return_state:
        states = y[1:] + y_rev[1:]
        y, y_rev = y[0], y_rev[0]
    if self.return_sequences:
        y_rev = K.reverse(y_rev, 1)

    if self.merge_mode == 'concat':
        output = K.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
        output = y + y_rev
    elif self.merge_mode == 'ave':
        output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
        output = y * y_rev
    elif self.merge_mode is None:
        output = [y, y_rev]
    else:
        raise ValueError
    
    if self.return_state:
        if self.merge_mode is None:
            return output + states
        return [output] + states
    return output

后记

与前一篇文章相比,本文显得有些粗糙。不过这也是意料之中的事情:LSTM的原理并不会因为它是TF1还是TF2发生变化,因此实现也不会有太大的变化,变的只会是类的组织方式。料想下一篇讨论PyTorch的文章,更多也会集中在结构设计上,毕竟具体实现已经在前一篇文章里描述得差不多了

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章