不定期讀一篇Paper之Dual Attention Network
前言
本文提出雙重注意力網絡,使用自注意力機制捕獲了空間和通道的全局依賴,從而實現獲得長距離的依賴,使得表達信息更加的充分。建議讀原文並閱讀源碼理解。
框架
整體結構
DAN整體上把位置自注意力機制和通道注意力機制並行處理,處理方式有一點像CBAM,只不過CBAM使用到的是注意力機制,區別與自注意力機制。位置自注意力能捕獲長距離的信息,從而獲得更多的空間信息,通道子注意力能夠捕獲不同通道之間的依賴,從而對特徵進行重新標定,思路上和SENet有點相似。兩種自注意力最後通過加法融合,從而獲得兼顧位置和通道的側重程度。如下圖:
注:
自注意力機制:自注意力機制會學習輸入特徵圖內部之間的相關性,所以,又稱intra-attention(內部注意力),是特徵圖內部全局信息的表徵。CNN和自注意力層的主要區別是,一個像素的新值依賴於圖像的其他像素。
注意力機制:注意力機制就是通過學習一個權重分佈,再把這個權重分佈施加到原來的特徵上面,以獲取更多的所需要關注目標的細節信息,而抑制其他無用信息,是一種資源分配方式,計算時需要外部信息的接入。
位置自注意力
通過自注意力機制,從而獲得長距離的依賴,即其中一個像素點的新值不在侷限於局部感受野,而是與全局信息相關聯。具體結果如下:
- 圖示
-
流程
下圖示中,query 代表B,key代表C,value代表D
- 代碼
class PostisionAttentiomModule(nn.Module):
"""位置注意模塊"""
def __init__(self, in_channels):
super(PostisionAttentiomModule, self).__init__()
self.query_conv = nn.Conv2d(in_channels, in_channels//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels, in_channels//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
# gamma參數初始化爲0,可學習
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
N, C, H, W = x.shape
# N * (H*W) * C'
query = self.query_conv(x).view(N, -1, H*W).permute(0, 2, 1)
# N * C' * (H*W)
key = self.key_conv(x).view(N, -1, H*W)
energy = torch.bmm(query, key)
# N * (H*W) * (H*W)
attention = self.softmax(energy)
# N * C * (H*W)
value = self.value_conv(x).view(N, -1, H*W)
# N * C * (H*W)
out = torch.bmm(value, attention.permute(0, 2, 1)) #注意softmax計算的維度
out = out.view(N, C, H, W)
# fusion
out = self.gamma*out + x
return out
通道自注意力
通過自注意力機制,探索內部通道之間的依賴。
- 圖示
-
流程
下圖示中,query、key和value分代表上述圖中由下到上三個分支。
- 代碼
class ChannelAttentionModule(nn.Module):
"""通道注意力機制"""
def __init__(self):
super(ChannelAttentionModule, self).__init__()
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
N, C, H, W = x.shape
# N*C*(H*W)
query = x.view(N, C, -1)
# N*(H*W)*C
key = x.view(N, C, -1).permute(0, 2, 1)
# N * C * C
energy = torch.bmm(query, key)
energy = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
attention = self.softmax(energy)
value = x.view(N, C, -1)
out = torch.bmm(attention, value)
out = out.view(N, C, H, W)
out = self.gamma*out + x
return out