# NOTE(review): this excerpt interleaves module construction (the nn.Conv2d
# assignments, which normally live in __init__) with forward-pass logic —
# it looks like a stitched-together fragment of a TSA-style fusion module;
# confirm the split against the full file before editing further.
# Two 3x3 convs that produce the embeddings compared for temporal attention.
self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
# aligned_fea is unpacked as 5-D [B, N, C, H, W]: B batches of N frame features.
B, N, C, H, W = aligned_fea.size() # N video frames
#### temporal attention
# Embed the center (reference) frame. .clone() detaches the slice from the
# original storage — presumably to avoid view-aliasing with the later
# re-binding of aligned_fea below; TODO confirm it is actually required.
emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone()) # [B, C, H, W]
## embedding
# Embed all N frames in a single batched conv call by folding N into the
# batch dimension, then restore the [B, N, ...] layout.
emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B*N, C, H, W]->[B, N, C(nf), H, W]
cor_l = []
for i in range(N):
    emb_nbr = emb[:, i, :, :, :] # [B, C, H, W]
    ## Dot product
    # Per-pixel correlation of frame i with the reference frame: channel-wise
    # dot product, keeping a singleton channel dim for the cat below.
    cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # [B, H, W] -> [B, 1, H, W]
    cor_l.append(cor_tmp)
## sigmoid
# Squash the raw correlations into (0, 1) attention weights, one map per frame.
cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # [B, N, H, W]
# Replicate each frame's attention map across all C feature channels, then
# flatten frames and channels together so it lines up with the flattened
# features below. repeat (not expand) is used so the final .view is legal
# on contiguous memory.
cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W) # [B, N, 1, H, W]->[B, N, C, H, W]->[B, N*C, H, W]
## element-wise multiplication
# Re-weight every aligned feature map by its temporal-attention probability.
# Note: this re-binds aligned_fea to the flattened, weighted tensor.
aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob #[B, N*C, H, W]
# fusion conv: using 1x1 to save parameters and computation
# NOTE(review): in_channels is nframes * nf — this assumes nframes == N and
# nf == C at forward time; verify against the constructor arguments.
self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
#### fusion
# Fuse the N*C weighted channels down to nf, followed by leaky ReLU.
fea = self.lrelu(self.fea_fusion(aligned_fea))