PyTorch 卷積計算

import math

import torch
from torch import nn

class CNN(nn.Module):
    """Single Conv2d + PReLU layer used to demonstrate conv output-size arithmetic.

    Args:
        scale_factor: Accepted for interface compatibility; not used by the
            layers built here.
        num_channels: Number of input channels.
        d: Number of feature maps produced by the convolution.
        s: Not used by the layers built here (kept for interface compatibility).
        m: Not used by the layers built here (kept for interface compatibility).
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super().__init__()
        self.first_part = nn.Sequential(
            # NOTE: padding=5//2 == 2 while the kernel is 3x3, so every spatial
            # dimension grows by 2: out = (in - 3 + 2*2) / 1 + 1 = in + 2.
            # Kept as-is because the demo output depends on it; use
            # padding=kernel_size // 2 to preserve the input size instead.
            nn.Conv2d(num_channels, d, kernel_size=3, padding=5 // 2),
            # Alternative factorized kernels that can be swapped in:
            # nn.Conv2d(num_channels, d, kernel_size=(1,3), padding=5//2),
            # nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=5//2),
            nn.PReLU(d),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Draw conv weights from N(0, 2 / fan_out) and zero the conv biases."""
        for layer in self.first_part:
            if isinstance(layer, nn.Conv2d):
                # fan_out = out_channels * kernel_h * kernel_w; weight[0][0] is a
                # single kernel_h x kernel_w slice. This matches
                # kaiming_normal_(mode="fan_out", nonlinearity="relu").
                fan_out = layer.out_channels * layer.weight[0][0].numel()
                # nn.init functions run under no_grad, so no `.data` is needed.
                nn.init.normal_(layer.weight, mean=0.0, std=math.sqrt(2 / fan_out))
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the PReLU-activated convolution feature maps of ``x``."""
        # Bug fix: the original computed this result but returned the input.
        return self.first_part(x)

# scale_factor is a required constructor argument (calling CNN() with no
# arguments raises TypeError); the demo layer itself does not use it.
model = CNN(scale_factor=2)
# Dummy input batch: 1 sample, 1 channel, 28 x 36 spatial size.
inp = torch.randn(1, 1, 28, 36)
pre = model(inp)
print(pre)                        # output tensor, still attached to autograd (shows grad_fn)
print(pre.detach())               # same values detached from the graph (preferred over .data)
print(pre[0, 0, 1, 1].item())     # one element as a Python float
print(pre.detach().size())        # size() and shape are equivalent ...
print(pre.detach().shape)
print(pre.shape)                  # ... and detaching does not change the shape
print(pre.size())

結果:
tensor([[[[ 6.6364e-02, -1.3303e-04, -3.5543e-02,  ..., -6.1688e-02,
            3.1265e-01, -8.2845e-03],
          [-3.5500e-02, -4.9340e-02,  1.7163e-01,  ...,  1.3813e-01,
           -8.2757e-03,  8.5905e-03],
          [ 8.1270e-02, -2.7835e-03, -2.5063e-02,  ...,  1.3616e-01,
           -6.6608e-02, -6.2187e-03],
          ...,
          [ 4.4998e-02, -4.2893e-02,  2.1688e-01,  ...,  3.1650e-01,
           -2.5981e-02,  2.6824e-03],
          [ 4.0486e-03,  1.2759e-02,  3.1908e-01,  ...,  2.2530e-01,
           -4.8846e-02,  3.9362e-02],
          [ 8.3523e-03,  1.5983e-02,  6.9260e-02,  ..., -2.0295e-02,
           -2.0337e-02,  3.9436e-02]]]], grad_fn=<PreluBackward>)
tensor([[[[ 6.6364e-02, -1.3303e-04, -3.5543e-02,  ..., -6.1688e-02,
            3.1265e-01, -8.2845e-03],
          [-3.5500e-02, -4.9340e-02,  1.7163e-01,  ...,  1.3813e-01,
           -8.2757e-03,  8.5905e-03],
          [ 8.1270e-02, -2.7835e-03, -2.5063e-02,  ...,  1.3616e-01,
           -6.6608e-02, -6.2187e-03],
          ...,
          [ 4.4998e-02, -4.2893e-02,  2.1688e-01,  ...,  3.1650e-01,
           -2.5981e-02,  2.6824e-03],
          [ 4.0486e-03,  1.2759e-02,  3.1908e-01,  ...,  2.2530e-01,
           -4.8846e-02,  3.9362e-02],
          [ 8.3523e-03,  1.5983e-02,  6.9260e-02,  ..., -2.0295e-02,
           -2.0337e-02,  3.9436e-02]]]])
-0.04934029281139374
torch.Size([1, 56, 30, 38])  # 計算過程: 30 = (28 - 3 + 2×2)/1 + 1, 38 = (36 - 3 + 2×2)/1 + 1
torch.Size([1, 56, 30, 38])
torch.Size([1, 56, 30, 38])
torch.Size([1, 56, 30, 38])

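計算公式（輸入尺寸 $H_{in}, W_{in}$，卷積核大小 $K$，填充 $P$，步長 $S$，dilation 為 1）:

$$H_{out} = \left\lfloor \frac{H_{in} + 2P - K}{S} \right\rfloor + 1, \qquad W_{out} = \left\lfloor \frac{W_{in} + 2P - K}{S} \right\rfloor + 1$$

例如本例 $K=3,\ P=2,\ S=1$：$H_{out} = (28 + 4 - 3)/1 + 1 = 30$，$W_{out} = (36 + 4 - 3)/1 + 1 = 38$。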

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章