注意力模型詳細解析

前言

這裏學習的注意力模型是我在研究image caption過程中的出來的經驗總結,其實這個注意力模型理解起來並不難,但是國內的博文寫的都很不詳細或說很不明確,我在看了 attention-mechanism後才完全明白。得以進行後續工作。

這裏的注意力模型是論文 Show,Attend and Tell:Neural Image Caption Generation with Visual Attention裏設計的,但是注意力模型在大體上來講都是相通的。

先給大家介紹一下我需要注意力模型的背景。


這裏寫圖片描述

I是圖片信息矩陣也就是[224,224,3],通過前面的cnn也就是所謂的sequence-sequence模型中的encoder,我用的是vgg19,得到a,這裏的a其實是[14*14,512]=[196,512],很形象吧,代表的是圖片被分成了這麼多個區域,後面就看我們單詞注意在哪個區域了,大家可以先這麼泛泛理解。通過了本文要講的Attention之後得到z。這個z是一個區域概率,也就是當前的單詞在哪個圖像區域的概率最大。然後z組合單詞的embedding去訓練。

好了,先這麼大概理解一下這張圖就好。下面我們來詳細解剖attention,附有代碼~

attention的內部結構是什麼?



這裏寫圖片描述

這裏的c其實一個隱含輸入,計算方式如下

首先我們這麼個函數:

def _get_initial_lstm(self, features):
    with tf.variable_scope('initial_lstm'):
        features_mean = tf.reduce_mean(features, 1)

        w_h = tf.get_variable('w_h', [self.D, self.H], initializer=self.weight_initializer)
        b_h = tf.get_variable('b_h', [self.H], initializer=self.const_initializer)
        h = tf.nn.tanh(tf.matmul(features_mean, w_h) + b_h)

        w_c = tf.get_variable('w_c', [self.D, self.H], initializer=self.weight_initializer)
        b_c = tf.get_variable('b_c', [self.H], initializer=self.const_initializer)
        c = tf.nn.tanh(tf.matmul(features_mean, w_c) + b_c)
        return c, h
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

上面的c你可以暫時不用管,是lstm中的memory state,輸入feature就是通過cnn跑出來的a,我們暫時考慮batch=1,就認爲這個a是一張圖片生成的。所以a的維度是[1,196,512],y向量代表的就是feature。

下面我們打開這個黑盒子來看看裏面到底是在做什麼處理。


Attention模塊

上圖中可以看到

mi=tanh(Wcmc+Wymyi)mi=tanh(Wcmc+Wymyi)

這裏的tanh不能替換成ReLU函數,一旦替換成ReLU函數,因爲有很多負值就會消失,會很影響後面的結果,會造成最後Inference句子時,不管你輸入什麼圖片矩陣的到的句子都是一樣的。不能隨便用激活函數!!!ReLU是能解決梯度消散問題,但是在這裏我們需要負值信息,所以只能用tanh

c和y在輸入到tanh之前要做個全連接,代碼如下。

        w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer)
        b = tf.get_variable('b', [self.D], initializer=self.const_initializer)
        w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer)

        h_att = tf.nn.relu(features_proj + tf.expand_dims(tf.matmul(h, w), 1) + b)    # (N, L, D)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

這裏的features_proj是feature已經做了全連接後的矩陣。並且在上面計算h_att中你可以看到一個矩陣的傳播機制,也就是relu函數裏的加法。features_proj和後面的那個維度是不一樣的。

def _project_features(self, features):
    with tf.variable_scope('project_features'):
        w = tf.get_variable('w', [self.D, self.D], initializer=self.weight_initializer)
        features_flat = tf.reshape(features, [-1, self.D])
        features_proj = tf.matmul(features_flat, w)  
        features_proj = tf.reshape(features_proj, [-1, self.L, self.D])
        return features_proj
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

然後要做softmax了,這裏有個點,因爲上面得到的m的維度是[1,196,512],1是代表batch數量。經過softmax後想要得到的是維度爲[1,196]的矩陣也就是每個區域的注意力權值。所以

out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), w_att), [-1, self.L])   # (N, L)
alpha = tf.nn.softmax(out_att) 
  • 1
  • 2

最後計算s就是一個相乘。

context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context')   #(N, D)
  • 1

這裏也是有個傳播的機制,features維度[1,196,512],後面那個維度[1,196,1]。

最後給個完整的注意力模型代碼。

def _attention_layer(self, features, features_proj, h, reuse=False):
    with tf.variable_scope('attention_layer', reuse=reuse):
        w = tf.get_variable('w', [self.H, self.D], initializer=self.weight_initializer)
        b = tf.get_variable('b', [self.D], initializer=self.const_initializer)
        w_att = tf.get_variable('w_att', [self.D, 1], initializer=self.weight_initializer)

        h_att = tf.nn.relu(features_proj + tf.expand_dims(tf.matmul(h, w), 1) + b)    # (N, L, D)
        out_att = tf.reshape(tf.matmul(tf.reshape(h_att, [-1, self.D]), w_att), [-1, self.L])   # (N, L)
        alpha = tf.nn.softmax(out_att)  
        context = tf.reduce_sum(features * tf.expand_dims(alpha, 2), 1, name='context')   #(N, D)
        return context, alpha    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

如果大家想研究整個完整的show-attend-tell模型,可以去看看github鏈接

                    <li class="tool-item tool-active is-like "><a href="javascript:;"><svg class="icon" aria-hidden="true">
                        <use xlink:href="#csdnc-thumbsup"></use>
                    </svg><span class="name">點贊</span>
                    <span class="count">1</span>
                    </a></li>
                    <li class="tool-item tool-active is-collection "><a href="javascript:;" data-report-click="{&quot;mod&quot;:&quot;popu_824&quot;}"><svg class="icon" aria-hidden="true">
                        <use xlink:href="#icon-csdnc-Collection-G"></use>
                    </svg><span class="name">收藏</span></a></li>
                    <li class="tool-item tool-active is-share"><a href="javascript:;"><svg class="icon" aria-hidden="true">
                        <use xlink:href="#icon-csdnc-fenxiang"></use>
                    </svg>分享</a></li>
                    <!--打賞開始-->
                                            <!--打賞結束-->
                                            <li class="tool-item tool-more">
                        <a>
                        <svg t="1575545411852" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="5717" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><defs><style type="text/css"></style></defs><path d="M179.176 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5718"></path><path d="M509.684 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5719"></path><path d="M846.175 499.222m-113.245 0a113.245 113.245 0 1 0 226.49 0 113.245 113.245 0 1 0-226.49 0Z" p-id="5720"></path></svg>
                        </a>
                        <ul class="more-box">
                            <li class="item"><a class="article-report">文章舉報</a></li>
                        </ul>
                    </li>
                                        </ul>
            </div>
                        </div>
        <div class="person-messagebox">
            <div class="left-message"><a href="https://blog.csdn.net/tongfanle3404">
                <img src="https://profile.csdnimg.cn/A/9/6/3_tongfanle3404" class="avatar_pic" username="tongfanle3404">
                                        <img src="https://g.csdnimg.cn/static/user-reg-year/2x/4.png" class="user-years">
                                </a></div>
            <div class="middle-message">
                                    <div class="title"><span class="tit"><a href="https://blog.csdn.net/tongfanle3404" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;}" target="_blank">tongfanle3404</a></span>
                                        </div>
                <div class="text"><span>發佈了0 篇原創文章</span> · <span>獲贊 1</span> · <span>訪問量 2814</span></div>
            </div>
                            <div class="right-message">
                                        <a href="https://im.csdn.net/im/main.html?userName=tongfanle3404" target="_blank" class="btn btn-sm btn-red-hollow bt-button personal-letter">私信
                    </a>
                                                        <a class="btn btn-sm  bt-button personal-watch" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;}">關注</a>
                                </div>
                        </div>
                </div>
發佈了77 篇原創文章 · 獲贊 66 · 訪問量 4萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章