PyTorch 卷積權重形狀

# -*- coding: utf-8 -*-
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np

from models_ import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr

import math


class CNN(nn.Module):
    """Small super-resolution network built from concatenated 3x3 conv stages.

    A stem convolution produces ``d`` feature maps.  Four follow-up stages
    each emit ``s`` feature maps; from the second stage on, every stage reads
    the channel-wise concatenation of the stem output and all earlier stage
    outputs.  A final convolution maps the full concatenation to
    ``scale_factor ** 2`` channels, which ``nn.PixelShuffle`` rearranges into
    a single channel upscaled by ``scale_factor`` in both spatial dimensions.

    Args:
        scale_factor: Upscaling factor handed to ``nn.PixelShuffle``.
        num_channels: Number of input channels fed to the stem convolution.
        d: Number of feature maps produced by the stem convolution.
        s: Number of feature maps produced by each follow-up stage.
        m: Not used by this class; kept so the signature stays unchanged.
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super().__init__()
        shuffle_channels = scale_factor ** 2

        # Stem: num_channels -> d feature maps at unchanged resolution.
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=3, padding=1),
            nn.PReLU(d),
        )

        # Each stage sees the stem output plus every previous stage output,
        # so its input width grows by ``s`` channels per stage.
        self.mid_part1 = nn.Sequential(
            nn.Conv2d(d, s, kernel_size=3, padding=1),
            nn.PReLU(s),
        )
        self.mid_part2 = nn.Sequential(
            nn.Conv2d(d + s, s, kernel_size=3, padding=1),
            nn.PReLU(s),
        )
        self.mid_part3 = nn.Sequential(
            nn.Conv2d(d + 2 * s, s, kernel_size=3, padding=1),
            nn.PReLU(s),
        )
        self.mid_part4 = nn.Sequential(
            nn.Conv2d(d + 3 * s, s, kernel_size=3, padding=1),
            nn.PReLU(s),
        )

        # Projection to scale_factor**2 channels, the layout PixelShuffle
        # consumes to build one upscaled output channel.
        self.mid_part = nn.Sequential(
            nn.Conv2d(d + 4 * s, shuffle_channels, kernel_size=3, padding=1),
            nn.PReLU(shuffle_channels),
        )

        # PixelShuffle has no learnable parameters.
        self.last_part = nn.PixelShuffle(scale_factor)
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise the conv layers of the stem and the four stages.

        Weights are drawn from a zero-mean normal distribution with
        ``std = sqrt(2 / (out_channels * kernel_h * kernel_w))`` and biases
        are set to zero.

        NOTE(review): the convolution inside ``self.mid_part`` is not touched
        here and keeps PyTorch's default initialisation - confirm that this
        is intended.
        """
        # Order matters: it fixes the sequence in which random numbers are
        # consumed, and therefore the weights obtained for a given seed.
        stages = (
            self.first_part,
            self.mid_part1,
            self.mid_part2,
            self.mid_part3,
            self.mid_part4,
        )
        for stage in stages:
            for layer in stage:
                if not isinstance(layer, nn.Conv2d):
                    continue
                kernel_h, kernel_w = layer.kernel_size
                fan = layer.out_channels * kernel_h * kernel_w
                nn.init.normal_(layer.weight.data, mean=0.0, std=math.sqrt(2 / fan))
                nn.init.zeros_(layer.bias.data)

    def forward(self, x):
        """Run the network on ``x``, printing every intermediate tensor size.

        Besides the sizes, the method prints the ``mid_part4`` module, the
        weight shape and bias values of its convolution, and the
        ``last_part`` module.

        Args:
            x: Input tensor; channels are concatenated along dimension 1.

        Returns:
            The output of ``self.last_part`` (pixel-shuffled tensor).
        """
        print(x.size())
        stem = self.first_part(x)
        print(stem.size())

        # Shape probe only: this concatenation is not fed to any layer.
        print(torch.cat([stem, x], 1).size())

        # The first stage reads the stem output directly.
        latest = self.mid_part1(stem)
        print(latest.size())
        collected = [stem, latest]

        # Remaining stages read the concatenation of everything so far.
        for stage in (self.mid_part2, self.mid_part3, self.mid_part4):
            merged = torch.cat(collected, 1)
            print(merged.size())
            latest = stage(merged)
            print(latest.size())
            collected.append(latest)

        # Inspect the last stage: module layout, conv weight shape and bias.
        print(self.mid_part4)
        for layer in self.mid_part4:
            if isinstance(layer, nn.Conv2d):
                print('weight形狀:', layer.weight.data.size())  # shape of the conv weight
                print(layer.bias.data)

        merged = torch.cat(collected, 1)
        print(merged.size())
        projected = self.mid_part(merged)
        print('out6.size():', projected.size())

        print(self.last_part)
        upscaled = self.last_part(projected)
        print(upscaled.size())
        return upscaled

if __name__ == '__main__':
    # Demo: build the network, push one random batch through it and print
    # the resulting shapes together with the total parameter count.
    model = CNN(scale_factor=3)
    print(model)

    # Named ``sample`` rather than ``input`` so the builtin input() is not
    # shadowed at module scope.
    sample = torch.randn(12, 1, 28, 36)
    with torch.no_grad():
        pre = model(sample)

    # Restrict the prediction to the [0, 1] range.
    pred = pre.clamp(0.0, 1.0)
    print('pred.size():', pred.size())

    # Indexing the last dimension drops it from the shape.
    print(pred[..., 0].shape)
    print(pred[..., 1].shape)
    print(pred[..., 2].shape)

    # Total number of parameters in the model.
    params = sum(p.numel() for p in model.parameters())
    print(params)


結果:
CNN(
  (first_part): Sequential(
    (0): Conv2d(1, 56, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=56)
  )
  (mid_part1): Sequential(
    (0): Conv2d(56, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part2): Sequential(
    (0): Conv2d(68, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part3): Sequential(
    (0): Conv2d(80, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part4): Sequential(
    (0): Conv2d(92, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part): Sequential(
    (0): Conv2d(104, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=9)
  )
  (last_part): PixelShuffle(upscale_factor=3)
)
torch.Size([12, 1, 28, 36])
torch.Size([12, 56, 28, 36])
torch.Size([12, 57, 28, 36])
torch.Size([12, 12, 28, 36])
torch.Size([12, 68, 28, 36])
torch.Size([12, 12, 28, 36])
torch.Size([12, 80, 28, 36])
torch.Size([12, 12, 28, 36])
torch.Size([12, 92, 28, 36])
torch.Size([12, 12, 28, 36])
Sequential(
  (0): Conv2d(92, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): PReLU(num_parameters=12)
)
weight形狀: torch.Size([12, 92, 3, 3])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
torch.Size([12, 104, 28, 36])
out6.size(): torch.Size([12, 9, 28, 36])
PixelShuffle(upscale_factor=3)
torch.Size([12, 1, 84, 108])
pred.size(): torch.Size([12, 1, 84, 108])
torch.Size([12, 1, 84])
torch.Size([12, 1, 84])
torch.Size([12, 1, 84])
41122  #模型總參數量


 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章