XLNet論文解讀+部分代碼解讀

XLNet Generalized Autoregressive Pretraining

Publisher:
作者: Zhilin Yang, Zihang Dai等
單位:Carnegie Mellon University, Google Brain
論文鏈接

XLNet:運行機制及和Bert的異同比較

飛躍芝麻街:XLNet 詳解

XLnet:比Bert更強大的預訓練模型

【BERT 系列 2】之 XLNet

中文XLNet預訓練模型

transformer-XL相對位置編碼示意圖

github上面別人提出針對雙向transformer-XL的疑問

1.Motivation

  作者認爲,Bert這種基於自編碼的具有雙向建模能力的模型性能比基於自迴歸建模的語言模型的性能要好。但是Bert因爲採用了Mask的訓練方式, 忽略了被Mask掉詞之間的依賴關係;同時因爲Bert是基於自編碼的,所以和基於自迴歸的模型相比較的,在面對生成任務的時候有缺陷;而且因爲Bert是基於transformer的,所以在序列長度方面有限制。所以作者就希望可以可以融合自編碼和自迴歸的優點,然後設計出來一個模型。

2.自迴歸語言模型和自編碼語言模型

2.1 自迴歸語言模型

  其實就是指RNN這一類的模型,這一類的模型的優化目標是最大化概率 p(x)=t=1Tp(xtX<t)p(x)=\sum_{t=1}^{T}p(x_t | X_{< t}) 或者最大化概率 p(x)=1t=Tp(xtX<t)p(x)=\sum_{1}^{t=T}p(x_t | X_{< t}) 。其中 X=(x1,...,xT)X=(x_1,...,x_T). 自迴歸語言模型有以下的優點:(其實就是rnn的優點)

  • 自迴歸語言模型的模型符合生成任務的需求,就是那種一個一個的生成我們需要的字符。類似人寫字一樣一個一個的寫出來.

  • 同時自迴歸語言模型可以學習要預測詞之間的的關係,因爲被預測的詞是根據上一個詞預測出來的。

然後也有一些缺點:

  • 但是自迴歸語言模型難以並行計算,以及無法提供一些語言理解任務中需要的雙向上下文信息。作者人爲, 雖然有雙線RNN的模型, 但是這些模型本身都是單向, 然後拼接成的雙向.

2.2 自編碼語言模型

  指的是以Bert爲代表這類語言模型,這類模型的特點是將輸入的數據破壞掉,然後通過剩下的輸入數據,再次重建出來被破壞掉的數據,即優化的是 p(xx~)p(x | \tilde{x}) 。然後自編碼語言麼的優點是:(就是Bert的優點, 其實就是和上面的RNN反着來吧)

  • 可以更好的提供上下文信息和提供並行計算

缺點是:

  • 從某些角度說,是在模型的輸入端加入噪聲,然後模型進行去除噪聲。因爲Bert引入了[MASK]符號,但是這個符號因爲不出現在微調階段,就導致了預訓練和微調之間的差距。而且作者人爲Bert這種mask操作導致模型學習不到被mask掉詞之間的關係.

  • 和預測的時候生成任務的生成不匹配,導致生成任務效果較差.

3.XLNet的主要改進

3.1 Permutation Language Modeling

  爲了引入自迴歸模型的優點,同時可以看到上下文,論文中提出了排列語言模型。例如序列[1,2,3,4],如果是爲了預測3,那麼我們怎麼樣才能使用自迴歸的方式讓3看到上下文呢?

  如圖3-1,打亂序列的順序之後輸入到模型裏面,就可以發現,當我們需要預測3的時候,我們能看到的只能是3前面的單詞,如果打亂序列順序之後,我們可以看到第一行3可以看到1,2;第二行3可以看到2,4;以此類推,那麼要預測的詞就可以看到上下文了。

1,2,3,42,4,3,11,4,3,231 \begin{aligned} &1, 2, 3, 4 \\ &2, 4, 3, 1 \\ &1, 4, 3, 2 \\ &圖3-1 \end{aligned}
在這裏插入圖片描述
圖 3-2: 論文提供的打亂順序的輸入示意圖, 圖中表示的都是3這個位置的單詞在不同輸入的順序下面可以看到的詞. 因爲採用了transformer-XL, 所以前面會有一個mem的記憶模塊.

3.2 雙流自我注意力結構

  雙流自我注意力結構應該說是對於Permutation Language Modeling的實現方式。

3.2.1 attention mask

  首先是加入的attention masks,因爲XLNet爲了保證預訓練的輸入和之後的微調的時候保證一致,不可能直接打亂序列的輸入順序。所以模型的輸入還是正常的序列順序。爲了實現打亂順序的需要,模型在進行attention的時候,進行了mask操作。

在這裏插入圖片描述
圖 3-3

  對於輸入順序如果是3-2-4-1的情況下,目前只看content stream的mask圖。對於第一行,代表的是1,因爲打亂順序之後,1相當於是最後輸入,那麼1可以看到所有的序列;對於第二行對應的是2,2相當於第二個輸入進去的,所以2能看到的是3和2,那麼對應的mask區域只有第二個和第三個可以不被mask掉。以此類推。

  但是這種處理方法,帶來了一個問題,例如依舊是3-2-4-1的輸入順序,在預測單詞4的時候,模型用這種mask方式可以看到4自己的信息;如果把4也mask掉,那麼在預測1的時候又看不到4的信息了。同時爲了解決Bert使用[MASK]代替被屏蔽單詞的問題,所以作者設計了一個新的結構去解決這個問題, 也就是加入了Query stream的另外一個流的自我注意力結構.

3.2.2 其餘的雙流操作

  雙流自我注意力結構分爲2部分,分別是內容流和查詢流。內容流,則是正常的transformer-XL的計算方式(和transformer-XL其實是略有不同的, 3.2.3會詳細的講和transformer-XL計算的差異),使用上面介紹的mask方法。查詢流中,attention中的Q只包含了輸入的位置信息, 而K,V則包含了內容信息,但是K,V包含的內容信息只包括輸入序列的位置t的前面的1-t個單詞的內容信息,並且不包含第t個單詞,所以和content的mask相比,對角線上的都mask掉了。
在這裏插入圖片描述
圖 3-4

  XLNet和Bert類似,採用了類似的“掩蓋”一部分輸入的序列,然後讓模型去預測。XLNet每次掩蓋的時候,選擇的都是打亂順序之後的序列的最後面的一部分,這樣也和自迴歸模型的模式類似。掩蓋的比例,根據論文的實驗,選擇的是1/7-1/6,也就是14.28%-16.67%。

在這裏插入圖片描述
圖 3-5

  XLNet採用的是transformer-XL的模型結構,對於位置編碼,直接採用transformer-XL的相對位置編碼方式。這裏進行的修改是相對段編碼。

  Bert採用的是絕對位置段編碼,但是因爲XLNet採用的是transformer-XL,所以需要使用上次的記憶數據,這裏也採用了相對句子段編碼。對於i計算j的注意力值的時候,如果ij來自同一段,那麼採用s+,否則採用s-,然後計算出來的aij直接加到正常的注意力值裏面即可。有兩個優點:

  • 增加了模型的泛化性

  • 保證了微調的時候遇到多個句子依舊可以正常使用

aij=(qi+b)Tsij,sij={s+,ijs,ij a_{ij} = (q_i + b)^T s_{ij}, 其中s_{ij} = \left\{\begin{matrix} s_+ &, ij來自同一段 \\ s_- &, ij來自不同段 \end{matrix}\right.

3.2.3 XLNet的雙向transformer-XL

  這裏針對的是內容流, 因爲雙流的查詢流的代碼沒看(因爲沒機器可以跑預訓練, 於是放棄了). 其實這裏的雙向transformer-XL和單向的transformer-XL的實現的差別主要還是在計算下面的公式的b和d上面:

Ai,jrel=ExiTWqTWk,EExja+ExiTWqTWk,RRijb+uTWk,EExjc+vTWk,RRijd A^{rel}_{i,j} = \underbrace{E_{x_i}^T W_q^T W_{k,E} E_{x_j}}_{a} + \underbrace{E_{x_i}^T W_q^T W_{k,R} \color{blue} R_{i-j}}_{b} + \underbrace{{\color{red} u^T} W_{k,E} E_{x_j}}_{c} + \underbrace{{\color{red}v^T} W_{k,R} \color{blue} R_{i-j}}_{d}

可以去看一下XLNet中生產位置信息的實現代碼:(下面截取的是huggingface的XLNet的pytorch版本的實現代碼, 和原版的tensorflow的基本完全一樣)

首先是生成 RijR_{i-j} 的部分:

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """create relative positional encoding."""
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == 'bi':
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == 'uni':
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

調用relative_positional_encoding來生成 RijR_{i-j}, 上面的代碼寫的比較複雜, 但是實際上我們需要關注的代碼只有下面這麼多:

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """create relative positional encoding."""
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == 'bi':
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == 'uni':
            ...
        else:
            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

        if self.bi_data:
            ...
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                ...
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

仔細看代碼的話, 我們可以發現, 這次生成的位置的長度, 實際上是: mem_len + input_len + input_len, 和以前的transformer-XL相比較的話, 這裏增加的一個 input_len長度的位置, 其實這裏增加的就是雙向的部分, 因爲原來的單向transformer-XL只有一個mem_len + input_len的長度, 這裏增加的input_len長度就是反向的位置信息. 然後剩下的關於位置信息的計算, 比較不一樣的應該就是之後的截取部分了, 具體的代碼位置是在類XLNetRelativeAttention中的rel_attn_core函數和``函數中:

    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None, head_mask=None):
        """Core relative positional attention operations."""

        ...

        # position based attention score
        bd = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift(bd, klen=ac.shape[1])

        ...

        return attn_vec
    
        @staticmethod
    def rel_shift(x, klen=-1):
        """perform relative shift to form the relative attention score."""
        x_size = x.shape

        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        # x = x[:, 0:klen, :, :]
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

        return x

怎麼說呢, 這裏沒有進行pad操作, 直接就截取了, 下面的示意圖大概可以說明這個過程:

在這裏插入圖片描述
圖3-6

圖3-6基本就是XLNet中的雙向transformer-XL的位置信息生成的過程, 怎麼說呢? 這裏假設是輸入了3個單詞, 然後記憶模塊可以記憶2個單詞, 那麼XLNet中生成的input_len + mem_len + input_len位置中, 前面的淺藍色的是從右到左的位置信息, 中間的是來自上次的記憶信息, 右邊的深藍色是從左到右的位置信息(其實嚴格的來說, 並沒有左右的區分, 只是爲了實現雙向, 所以就弄了兩個, 反正論文裏面沒寫, 我就這樣比喻一下). 然後經過這一系列的操作, 我們可以看到最終結果的第一行中的位置信息裏面包含所有的mem部分的位置以及剩下的所有的輸入字符的位置, 剩下的每一行都是2個綠色的加3個藍色的構成的, 當然有淺藍和深藍, 但是確實都包含了所有的字符的位置信息, 這裏我也比較迷, 最後搞成了這樣的結果, 但是模型事實的運行結果證明這樣是可行的.(或許是我理解錯了)

4.於Bert的對比

  作者認爲,Bert無法學習到被mask掉部分的詞之間的信息,作者舉例,對於句子“New York is a city”,預測的目標是“New York”,那麼Bert和XLNet的優化目標分別是:

ξBert=log p(Newisacity)+log p(Yorkisacity)ξXLNet=log p(Newisacity)+log p(YorkNew,isacity) \begin{aligned} & \xi_{Bert} = log \ p(New | is a city) + log \ p(York | is a city) \\ & \xi_{XLNet} = log \ p(New | is a city) + log \ p(York | New, is a city) \end{aligned}

  根據優化目標可以看到,XLNet在預測出來New,會在預測York的時候把New加入到先決條件中。這樣,被mask掉的詞也可以學習它們之間的關係。

5.實驗對比

5.1 長文檔閱讀理解

  RACE數據集是一個針對中國中學生和高中生的英語考試的數據集,數據集包含近10萬個問題,是目前最難的閱讀理解數據集,且數據集中的段落的平均長度在300個單詞之上,比一般的閱讀理解的數據集的長度都長的多。

  XLNet對於Bert提升了大概接近10%左右,根據後面的一些實驗分析,這裏的提升除了加入了PLM,更多的可能是因爲使用了transformer-XL。

在這裏插入圖片描述
圖 5-1

  SQuAD數據集和RACE類似,都是長文檔級別的閱讀理解數據集,這裏的效果提升針對Bert而言提升也比較明顯。

在這裏插入圖片描述
圖 5-2

5.2 消融實驗

  參考https://zhuanlan.zhihu.com/p/70257427,論文中使用和Bert相同的訓練量,訓練了一個XLNet-base模型用於和Bert進行對比。

在這裏插入圖片描述
圖5-3

  首先看DAE+transformer-XL的實驗結果,這裏相當於Bert中的transformer替換成了transformer-XL,是研究長文檔因素造成的影響。RACE和SQuAD2.0都是長文檔的閱讀理解,分數提升1和3個點,但是MNLI和SST-2都是句對分類任務,提升就不明顯了。說明transformer-XL帶來了長文檔的效果提升。

  之後再參考XLNet-Base的效果,這裏體現的是PLM帶來的提升,可以看到四個數據集都有1個點左右的提升,說明PLM是可以給模型帶來收益的。

  除此之外,根據網上別人的分析,根據前面XLNet-large的得分情況,再對比消融研究中的XLNet-Base的得分情況,可以大概得出訓練數據量的提升(接近10倍Bert訓練量)給XLNet的模型在長文本閱讀理解上的提升佔到30%左右。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章