music transformer:generating music with long-term structure

音樂生成最近大火,在這篇文章以前RNN被用來做這種長序列生成問題,但是與機器翻譯這種任務不同,音樂是一種比自然語言更加Hierarchial的數據表徵。RNN在長序列生成問題上由於自身結構原因,一個序列中如果兩個信息相隔位置過遠,RNN的遞歸式結構顯然無法捕捉到這種信息,這在機器翻譯任務上已經得到了證實。對於音樂來說,音樂非常強調結構性,古典音樂更是如此。目前來說,從谷歌magenta在2017年使用performanceRNN生成的音樂片段來看,生成的音樂偶爾有一兩秒比較具有音樂性的片段,從整體來看,則是亂彈琴,不具有long term structure。2018年transformer的提出,刷新了機器翻譯人任務在不同數據集上的結果。


本篇文章還會繼續更新
文章來自Google magenta, 論文連接博客連接

論文筆記放在github

1.Motivation

  • 音樂作品在結構和意義上有大量的重複,這種自相關性體現在多重時間尺度上,從主體到樂句上都重複使用整段的音樂章節,比如說在音樂片段中的ABA結構。
  • 使用self-attention機制的transformer模型在機器翻譯這類需要保持sequence很長連貫性的任務上取得巨大成功,這表明或許transformer很適合構建音樂生成和音樂表現,當然,相對的timing對音樂來說也很重要,現有的使用absolute position embedding的transformer模型在直覺上並不make sense,對於relative position embedding的transformer,它的relative position embedding是基於query&key pair-wise position,這對音樂生成任務上並不可行,因爲算法中表徵relative position位置的中間向量空間複雜度上是sequence 長度的平方,文中成功的將這個中間向量的空間複雜度降到了sequence 長度量級。使得這個transformer能夠在給定一個motif生成數分鐘長的音樂,或者基於seq2seq設置,給定一個旋律生成伴奏。

2.Transformer 原理簡介

Transformer的原理介紹現在太多了,在這裏再贅述一遍,transformer中一個層包括一個self-attention層以及一個feedfoward層。

在這裏插入圖片描述

由於transformer中的self-attention機制中沒有明顯地對相對位置進行建模,而是額外增加一種絕對的位置表徵(attention原文中position embedding:使用正弦函數和餘弦函數來構建每個位置的值。在2018NAACL論文Self-Attention with Relative Position Reoresentations中提出了一種考慮相對位置表徵的self-attention機制。

3.Relative positional self-attention

雖然文中輕描淡寫地總結了那篇NAACL的論文改進的帶relative attention positionde的self attention機制。

但是讀到文章中畫橫線的句子,我很懵逼:ErE^r怎麼就是(H,L,DhH,L,D_{h})呢?這個ErE^r和queries是怎麼得到SrelS^{rel}的呢?那最後這個(LLL*L)的RelativeAttention矩陣又是怎麼得到的呢?


同樣中間張量R(L,L,DhL,L,D_{h}), 張量裏面的元素代表什麼?SrelS^{rel}=QRTQR^{T}是矩陣和矩陣怎麼進行相乘的?

要想回答這些問題必須要看這篇NAACL的論文了,還要結合代碼,參考博客:這篇博客對transformer中的相對位置表徵講的很透徹

我這裏只簡單的概括一下:作者提出在transformer中加入一組可以訓練的相對距離表示(RPR),從而是使輸出帶有一定的順序信息。RPR會在計算詞i的輸出表示ziz_{i}, 詞i對詞j的注意力權重係數時用到

一個最簡單的例子:一個句子有五個詞,設k=4,那麼這五個詞每個詞都有9個相對距離表示(一個是和自己的距離,和上文的四個詞的距離,和下文的四個詞的距離),設置詞ii與自己的距離在RPR中對應index4,則詞ii與詞i+1i+1的距離在RPR對應index5,詞ii與詞i1i-1的距離在RPR中對應index3

值得注意的是,論文中提到,詞間距離的最大值限制在一個常數k,這意味着需要學習的RPR的數量書2k+1(上文k個詞,下文k個詞,當前詞),往右間隔超過k的詞對應RPR中第2k個index,往左間隔超過k的詞對應RPR中第0個index,如果一個有10個詞的輸入序列,k設爲3,那麼RPR的查詢表爲:

在這裏插入圖片描述

  • 超過一定距離,再精確的相對位置信息是時沒有用的。
  • 限制最長距離能夠提升模型在對未在訓練階段出現過的長度的序列的泛化能力

在NAACL這篇文章中把兩種注意力機制表示如下

  • 普通的self-attention機制表示:

    Zi=j=1naij(xjWv)(1)Z_i=\sum_{j=1}^na_{ij}(x_jW^v) \tag{1}

    aij=exp  eijk=1nexp  eik(2)a_{ij}=\frac{exp\;e_{ij}}{{\displaystyle\sum_{k=1}^n}exp\;e_{ik}} \tag{2}

    eij=(xiWQ)(xjWK)dk(3)e_{ij}=\frac{\left(x_iW^Q\right)\left(x_jW^K\right)}{\sqrt{d_k}} \tag{3}

  • 考慮相對位置表徵的self-attention表示:

    Zi=j=1naij(xjWv+aijV)(4)Z_i=\sum_{j=1}^na_{ij}(x_jW^v+a_{ij}^V) \tag{4}

    aij=exp  eijk=1nexp  eik(5)a_{ij}=\frac{exp\;e_{ij}}{{\displaystyle\sum_{k=1}^n}exp\;e_{ik}} \tag{5}

    eij=(xiWQ)(xjWK+aijK)dk(6)e_{ij}=\frac{\left(x_iW^Q\right)\left(x_jW^K+a_{ij}^K\right)}{\sqrt{d_k}} \tag{6}

其中aijVa_{ij}^V,aijVa_{ij}^V是兩個relative position representation(RPR)

Transformer的輸入是一個大小爲 (batch_size, seq_length, embedding_dim)的張量。在不帶RPR嵌入的情況下,Transformer能夠利用batch_size * h 並行地進行矩陣乘法來計算 eᵢⱼ (式子2) 。每一次矩陣乘法都會計算給定輸入序列和注意力頭中所有的元素的eᵢⱼ 。這個過程使用下面的表達式實現的

我們首先使用了矩陣乘法的性質將式子(6)重寫爲:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-UhfO6pgH-1572407589864)(https://cdn-images-1.medium.com/max/1000/1*VB1i8gI67cPHQ7bkmVVk2g.png)]

分子的左半部分和式子 (2)相同,因此在矩陣乘法中能夠高效運算。右半部分就有點技巧性了。這部分代碼實現定義在函數 relative_attention_inner 中,因此我會較簡單地把大體邏輯介紹一下。

  • 分子左半部分的大小爲 (batch_size, h, seq_length, seq_length)。這個張量的行i列j上的元素代表了詞i的query向量和詞j的key向量的點積的結果 。因此,我們的目標是產生另一個和這個張量大小相同的張量,而這個張量的各個元素應該是詞i的query向量和詞i與詞j之間的RPR嵌入的點積的結果(譯者注:也就是分子右半部分)。
  • 首先,我們使用查表的形式爲一個給定的輸入序列生成RPR嵌入張量A,A的形狀是(seq_length, seq_length, dₐ) 【這其實就是music tranformer文章提到的中間向量R, 也就是我加了紅色下劃線的地方】。然後,我們對A進行轉置,使它的形狀變成 (seq_length, dₐ , seq_length) ,寫成 Aᵀ。
  • 接下來,我們計算輸入序列所有元素的query向量,得到一個 (batch_size, h, seq_length, dz)形狀的張量。然後對其進行轉置,形狀變爲 (seq_length, batch_size, h, dz) ,然後變形爲 (seq_length, batch_size * h, dz)的張量。這個張量現在就能與 Aᵀ相乘了。這個乘法可以視爲矩陣 (batch_size * h, dz) 和矩陣 (dₐ, seq_length)的乘法。基本上就是計算每個位置的query向量和對應的RPR向量嵌入的點積這部分解釋了music tranformer文章中Q 是如何reshape成(L,1,DhL, 1, D_{h}), 文章爲了簡要描述將batchsize設爲了1,論文下面的腳註有解釋,然後再將Q和R的轉置進行矩陣相乘】。
  • 上面的乘法得到一個形狀爲 (seq_length, batch_size * h, seq_length)的張量。我們只需要將其變形爲(seq_length, batch_size, h, seq_length)的形狀,然後再轉置得到形狀爲 (batch_size, h, seq_length, seq_length) 的張量,這樣我們就能將它和分子左半部分進行相加了。

4.Efficient implementation of relative position-based attention

算法簡介

不廢話了,本篇論文的一個算法貢獻是,它不需要中間張量R(seq_length, seq_length, dₐ)來計算query QQ和相對位置向量表徵RR的矩陣乘法,正是這個向量帶來了O(L2dL^2d)的空間複雜度。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-7coNwOeL-1572407213791)(/Users/wanglei/博客文件/博客圖片/paper_segment3.jpg)]

不再採用18年NAACL那篇論文提出的方法,而是將直接將Q和RPR向量(也就是那個查詢表展示的向量,形狀爲seq_length, seq_length)相乘。因爲作者發現通過從QRTQR^{T}中求得的SrelS^{rel}同樣可以從QErTQE^{rT}中通過變換求得,這樣就避免了求中間向量R。

但是QErTQE^{rT}(iq,r)(i_{q},r)代表的是位置爲iqi_{q}的query向量和相對距離爲r的embedding向量的點積,而不是SrelS^{rel}iqi_{q}的query向量和位置jkj_{k}和位置iqi_{q}的相對距離表徵向量的點積。所以接下來對QErTQE^{rT}進行skew(pad,reshape,slice)操作可以得到SrelS^{rel}。對應時間複雜度,同樣都是O(L2D)O(L^{2}D),但是在seq_length=650時,比原來的算法快6倍。

skew操作

  • Pad: 在QErTQE^{rT}的左邊補上一列長爲seq_length的向量
  • Reshape: 按照如下規則reshape矩陣:行的索引保持不變,列的索引計算爲jk=r(L1)+rqj_{k}=r-(L-1)+r_{q}
  • Slice: 最後保留最後L行,列保持不變
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章