PyTorch implementation of temporal attention in EDVR

In EDVR's TSA (Temporal and Spatial Attention) fusion module, temporal attention measures how similar each aligned frame is to the center (reference) frame: both are projected by embedding convolutions, their channel-wise dot product is passed through a sigmoid to get a per-pixel weight map of shape [B, N, H, W], and these weights modulate the aligned features before the fusion convolution.
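In symbols (notation mine, following the EDVR paper's description of temporal attention): for aligned features $F_{t+i}$ and the reference frame $F_t$, the per-pixel attention map is
$$M_{t+i} = \sigma\big(\theta(F_{t+i}) \cdot \phi(F_t)\big),$$
where $\theta$ and $\phi$ are the embedding convolutions `tAtt_1` and `tAtt_2`, the dot product runs over the channel dimension, $\sigma$ is the sigmoid, and the modulated features are $\tilde{F}_{t+i} = F_{t+i} \odot M_{t+i}$.
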
# In __init__: embedding convolutions for temporal attention (nf = number of feature channels)
self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # embeds every aligned frame
self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # embeds the center (reference) frame

# In forward: aligned_fea comes from the alignment module, shape [B, N, C, H, W]
B, N, C, H, W = aligned_fea.size()  # N = number of video frames
#### temporal attention
emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())  # reference (center) frame embedding, [B, C, H, W]
## embedding of all frames in one batched pass
emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W)  # [B*N, C, H, W] -> [B, N, C, H, W]

cor_l = []
for i in range(N):
    emb_nbr = emb[:, i, :, :, :]  # embedding of the i-th frame, [B, C, H, W]
    ## dot-product similarity with the reference embedding, summed over channels
    cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1)  # [B, H, W] -> [B, 1, H, W]
    cor_l.append(cor_tmp)
## sigmoid gives per-pixel attention weights in (0, 1)
cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))  # [B, N, H, W]
cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)  # [B, N, 1, H, W] -> [B, N, C, H, W] -> [B, N*C, H, W]
## element-wise multiplication: re-weight each frame's features by its attention map
aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob  # [B, N*C, H, W]

# In __init__: fusion conv, 1x1 kernel to save parameters and computation
self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
#### In forward: fuse the attention-weighted features back to nf channels
fea = self.lrelu(self.fea_fusion(aligned_fea))  # self.lrelu is a LeakyReLU defined in __init__
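
For reference, here is a minimal self-contained sketch that wraps the snippets above into a single module so the shapes can be checked end to end. The class name TemporalAttentionFusion, the defaults nf=64 and nframes=5, and the random test input are my own assumptions for illustration; they are not part of the official EDVR code.

import torch
import torch.nn as nn


class TemporalAttentionFusion(nn.Module):
    """Sketch of EDVR-style temporal attention followed by a 1x1 fusion conv."""

    def __init__(self, nf=64, nframes=5, center=None):
        super().__init__()
        self.center = nframes // 2 if center is None else center
        # embedding convolutions for temporal attention
        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # fusion conv: 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        # aligned_fea: [B, N, C, H, W], already aligned to the center frame
        B, N, C, H, W = aligned_fea.size()
        emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())    # [B, C, H, W]
        emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W)  # [B, N, C, H, W]

        cor_l = []
        for i in range(N):
            emb_nbr = emb[:, i, :, :, :]                                       # [B, C, H, W]
            cor_l.append(torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1))         # [B, 1, H, W]
        cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))                      # [B, N, H, W]
        cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)  # [B, N*C, H, W]

        aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob                 # [B, N*C, H, W]
        return self.lrelu(self.fea_fusion(aligned_fea))                        # [B, nf, H, W]


# quick shape check (assumed values: batch 2, 5 frames, nf=64, 32x32 feature maps)
if __name__ == "__main__":
    fusion = TemporalAttentionFusion(nf=64, nframes=5)
    x = torch.randn(2, 5, 64, 32, 32)
    print(fusion(x).shape)  # torch.Size([2, 64, 32, 32])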