SA-GAN: self-attention 的 pytorch 實現(針對圖像)

問題

基於條件的卷積GAN 在那些約束較少的類別中生成的圖片較好,比如大海,天空等;但是在那些細密紋理,全局結構較強的類別中生成的圖片不是很好,如人臉(可能五官不對應),狗(可能狗腿數量有差,或者毛色不協調)。

可能的原因

大部分卷積神經網絡都嚴重依賴於局部感受野,而無法捕捉全局特徵。另外,在多次卷積之後,細密的紋理特徵逐漸消失。

SA-GAN解決思路

不僅僅依賴於局部特徵,也利用全局特徵,通過將不同位置的特徵圖結合起來(轉置就可以結合不同位置的特徵)。

##############################
# self attention layer
# author Xu Mingle
# time Feb 18, 2019
##############################
import torch.nn.Module
import torch
import torch.nn.init
def init_conv(conv, glu=True):
    init.xavier_uniform_(conv.weight)
    if conv.bias is not None:
        conv.bias.data.zero_()

class SelfAttention(nn.Module):
    r"""
        Self attention Layer.
        Source paper: https://arxiv.org/abs/1805.08318
    """
    def __init__(self, in_dim, activation=F.relu):
        super(SelfAttention, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        
        self.f = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8 , kernel_size=1)
        self.g = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8 , kernel_size=1)
        self.h = nn.Conv2d(in_channels=in_dim, out_channels=in_dim , kernel_size=1)
        
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1)

        init_conv(self.f)
        init_conv(self.g)
        init_conv(self.h)
        
    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention feature maps
                
        """
        m_batchsize, C, width, height = x.size()
        
        f = self.f(x).view(m_batchsize, -1, width * height) # B * (C//8) * (W * H)
        g = self.g(x).view(m_batchsize, -1, width * height) # B * (C//8) * (W * H)
        h = self.h(x).view(m_batchsize, -1, width * height) # B * C * (W * H)
        
        attention = torch.bmm(f.permute(0, 2, 1), g) # B * (W * H) * (W * H)
        attention = self.softmax(attention)
        
        self_attetion = torch.bmm(h, attention) # B * C * (W * H)
        self_attetion = self_attetion.view(m_batchsize, C, width, height) # B * C * W * H
        
        out = self.gamma * self_attetion + x
        return out
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章