Implementing U-Net in PyTorch

Let's start with a picture: the network implements the diagram below (the standard U-Net architecture, with a contracting path, a bottleneck, and an expansive path joined by skip connections).

import torch
from torch import nn
from torch.nn import functional as F

class Unet(nn.Module):

    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),  # batch normalization
            nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),  # batch normalization
        )
        return block
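
    # Each unpadded 3x3 conv trims 2 pixels per spatial dim, so one contracting
    # block maps (H, W) to (H - 4, W - 4), matching the "valid" convolutions of
    # the original U-Net paper. Conv -> ReLU -> BatchNorm is kept as in the
    # original code, though Conv -> BatchNorm -> ReLU is the more common order.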

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            # Transposed convolution ("up-conv"); output size is
            # (input_size - 1) * stride - 2 * padding + kernel_size + output_padding,
            # which with k=3, s=2, p=1, op=1 works out to exactly 2 * input_size.
            nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
        )
        return block
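
    # Net spatial effect of one expansive block: two valid convs give H - 4,
    # then the stride-2 up-conv doubles it, i.e. H -> 2 * (H - 4).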

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = nn.Sequential(
            nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channel),
            nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )
        return block
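
    # The final 3x3 conv is padded (padding=1), so it keeps the spatial size
    # while mapping to out_channels class scores. The trailing ReLU + BatchNorm
    # follow the original code; emitting raw logits (no activation) would be
    # the more common head before a cross-entropy loss.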

    def __init__(self, in_channel, out_channel):
        super(Unet, self).__init__()
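
        # Contracting path: at each step the channel count doubles
        # (64 -> 128 -> 256 -> 512) and a 2x2 max-pool halves the resolution.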
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv_encode4 = self.contracting_block(256, 512)
        self.conv_maxpool4 = nn.MaxPool2d(kernel_size=2)


        # The bottommost layer of the diagram (bottleneck)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=512, out_channels=1024),
            nn.ReLU(),
            nn.BatchNorm2d(1024),  # batch normalization
            nn.Conv2d(kernel_size=3, in_channels=1024, out_channels=1024),
            nn.ReLU(),
            nn.BatchNorm2d(1024),  # batch normalization
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2,
                               padding=1, output_padding=1)
        )

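        # Channel bookkeeping: each skip concatenation doubles the channels
        # (e.g. 512 from the up-conv + 512 from encode_block4 = 1024), which is
        # why every decode block's in_channels is twice its mid_channel.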
        self.conv_decode4 = self.expansive_block(1024, 512, 256)
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)

        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        '''
        Concatenate the upsampled decoder feature map with the skip-connection
        feature map from the contracting path.
        :param upsampled: decoder feature map (after the up-convolution)
        :param bypass: encoder feature map carried over by the skip connection
        :param crop: center-crop `bypass` so its spatial size matches `upsampled`
        :return: the two maps concatenated along the channel dimension
        '''
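        # F.pad with negative padding values crops instead of padding: this
        # trims c pixels from every spatial border of `bypass`, center-cropping
        # the encoder map to match `upsampled`. (Assumes the size difference is
        # even, which holds for the canonical 572x572 input.)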
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        encode_block4 = self.conv_encode4(encode_pool3)
        encode_pool4 = self.conv_maxpool4(encode_block4)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool4)

        decode_block4 = self.crop_and_concat(bottleneck1, encode_block4, crop=True)
        cat_layer4 = self.conv_decode4(decode_block4)

        decode_block3 = self.crop_and_concat(cat_layer4, encode_block3, crop=True)
        cat_layer3 = self.conv_decode3(decode_block3)

        decode_block2 = self.crop_and_concat(cat_layer3, encode_block2, crop=True)
        cat_layer2 = self.conv_decode2(decode_block2)

        decode_block1 = self.crop_and_concat(cat_layer2, encode_block1, crop=True)
        final_layer = self.final_layer(decode_block1)

        return final_layer


unet = Unet(in_channel=1, out_channel=2)
inputs = torch.zeros(2, 1, 572, 572)  # plain tensor; Variable is deprecated since PyTorch 0.4
outputs = unet(inputs)
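
To see why a 572×572 input comes out as 388×388, here is a minimal, pure-Python sketch of the size arithmetic (the helper names are mine, not part of the network):

# Spatial-size bookkeeping for the canonical 572x572 U-Net input.
def conv_block(h):  # two unpadded 3x3 convs: each trims 2 pixels
    return h - 4

def up(h):  # ConvTranspose2d with k=3, s=2, p=1, op=1: doubles the size
    return 2 * h

h = 572
for _ in range(4):          # contracting path: conv block, then 2x2 max-pool
    h = conv_block(h) // 2  # 572->284, 284->140, 140->68, 68->32
h = up(conv_block(h))       # bottleneck: 32 -> 28 -> 56
for _ in range(3):          # expansive path: conv block, then up-conv
    h = up(conv_block(h))   # 56->104, 104->200, 200->392
h = conv_block(h)           # final block: two valid convs (last conv is padded)
print(h)                    # 388, so outputs.shape is torch.Size([2, 2, 388, 388])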

