Introduction
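I had planned to write a TensorFlow implementation, but someone on CSDN has already published one. Over the past month I have been working with PyTorch, and in my opinion it is much easier to use than TensorFlow. This article is my summary and distillation of the Unet paper. One caveat: the original Unet does not use BN (Batch Normalization), so I add that step in this article.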
If you want to install Python and PyTorch, or find out more about them, see the Python and PyTorch websites.
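In image segmentation there are two main schools: U-shaped networks and dilated convolution. This article covers U-Net, the classic U-shaped network. As backbone networks have evolved, most of the networks derived from Unet improve on it without changing the underlying idea very much, for example FCDenseNet (which combines DenseNet with Unet) and Unet++.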
Unet
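Unet is an autoencoder-style encoder-decoder network designed for medical image segmentation. It is also regarded as a kind of FCN (fully convolutional network). It can be split into two parts, down (encoder) and up (decoder). The main structure of down is a conv followed by a maxpool. The main structure of up is an upsample followed by a conv.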
The core idea of Unet
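To understand this, first build some intuition for what convolution does. Take a simple CNN trained on the MNIST dataset for digit recognition: it abstracts a 28*28 image into a vector of scores for the digits 0-9. Convolution can be seen as feature extraction: it pulls abstract concepts out of the input. But both pool and conv lose spatial information, and pooling loses far more of it. For image segmentation, spatial information is just as important as abstract information. Since every pool discards a lot of spatial information, the features before a maxpool carry more spatial information than the features after it. Unet therefore connects each down feature to the corresponding up feature.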
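The loss of spatial detail in pooling can be seen on a toy tensor. The following is a minimal sketch, not code from the original project: the 4x4 input and the choice of nearest-neighbour upsampling are assumptions made purely for illustration.

```python
import torch
import torch.nn as nn

# A toy feature map with shape (batch, channels, height, width) = (1, 1, 4, 4).
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

pool = nn.MaxPool2d(2)                                   # halves height and width
upsample = nn.Upsample(scale_factor=2, mode='nearest')   # doubles them again

pooled = pool(x)             # 4 values now summarise the original 16
restored = upsample(pooled)  # same size as x, but the fine detail is gone

print(pooled.shape)               # torch.Size([1, 1, 2, 2])
print(torch.equal(restored, x))   # False

# The skip connection: keep the pre-pool feature and stack it with the
# upsampled one along the channel dimension.
merged = torch.cat([x, restored], dim=1)
print(merged.shape)               # torch.Size([1, 2, 4, 4])
```

Upsampling brings back the size but not the values that pooling threw away, which is why the decoder is given the pre-pool feature as well.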
The structure of Unet
For the gray "copy and crop" arrows in the architecture figure: copy means concatenate, and crop makes the two feature maps the same height and width.
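As a rough illustration of "copy and crop" as the paper describes it, the sketch below center-crops the encoder feature to the decoder feature's size and then concatenates along channels. The helper names `center_crop` and `copy_and_crop` are hypothetical; the shapes are those of the deepest skip connection in the paper's figure (64x64 cropped to 56x56, 512 channels on each side).

```python
import torch

def center_crop(feature, target_h, target_w):
    """Crop the last two dimensions of `feature` around the centre."""
    _, _, h, w = feature.shape
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return feature[:, :, top:top + target_h, left:left + target_w]

def copy_and_crop(skip, upsampled):
    """Crop the encoder feature to the decoder feature's size, then concatenate."""
    cropped = center_crop(skip, upsampled.shape[2], upsampled.shape[3])
    return torch.cat([cropped, upsampled], dim=1)

skip = torch.randn(1, 512, 64, 64)        # encoder feature before the last maxpool
upsampled = torch.randn(1, 512, 56, 56)   # decoder feature after the first up step
merged = copy_and_crop(skip, upsampled)
print(merged.shape)   # torch.Size([1, 1024, 56, 56])
```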
The left half is the down path and the right half is the up path. Let's go through the two parts in turn.
Down Path
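The input image tile in the figure is our input training data. Apart from the first layer, which is two convs, every other layer can be seen as a maxpool followed by two convs. Almost all of the convs in Unet appear as a pair of consecutive convs, so for convenience we first define a double_conv class.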
# Implement double conv
import torch
import torch.nn as nn


class double_conv(nn.Module):
    ''' Conv2d => Batch_Norm => ReLU => Conv2d => Batch_Norm => ReLU
    '''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        # a 3x3 kernel with padding=1 keeps the height and width unchanged
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        self.conv.apply(self.init_weights)

    def forward(self, x):
        x = self.conv(x)
        return x

    @staticmethod
    def init_weights(m):
        # Xavier-initialise the conv weights and zero the biases
        if type(m) == nn.Conv2d:
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
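The `padding=1` in `double_conv` is what keeps the feature map size fixed, whereas the unpadded convs in the original paper shrink it by 2 pixels each. A small pure-Python sketch of the usual output-size formula makes this concrete; the function names `conv_out` and `trace_unpadded_unet` are made up for this example, and the trace assumes the paper's layout of four pooling levels.

```python
def conv_out(size, kernel=3, padding=0, stride=1):
    """Output length of a convolution along one spatial axis."""
    return (size - kernel + 2 * padding) // stride + 1

def trace_unpadded_unet(size, depth=4):
    """Follow one spatial dimension through an unpadded U-Net."""
    for _ in range(depth):
        size = conv_out(conv_out(size))  # two unpadded 3x3 convs
        size //= 2                       # 2x2 max pool
    size = conv_out(conv_out(size))      # bottleneck convs
    for _ in range(depth):
        size *= 2                        # 2x2 up-convolution
        size = conv_out(conv_out(size))  # two unpadded 3x3 convs
    return size

print(conv_out(572, padding=1))   # 572: padded conv keeps the size
print(conv_out(572))              # 570: unpadded conv loses 2 pixels
print(trace_unpadded_unet(572))   # 388: the paper's 572x572 tile gives a 388x388 output
```

With padded convs the output mask has the same size as the input whenever the input size is divisible by 16, which is one reason many reimplementations use padding.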
Next we implement the input conv, which really needs nothing more than a single double_conv.
# Implement input conv
class inconv(nn.Module):
    ''' input conv layer
        maps the input image (e.g. 3 channels) to 64 channels
        The only difference between `inconv` and `down` is the maxpool layer
    '''
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x
Next we implement the down class, whose structure is a maxpool followed by a double_conv.
class down(nn.Module):
    ''' normal down path
        MaxPool2d => double_conv
    '''
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        # the 2x2 max pool halves the height and width before the convs
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x
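To see what the down path does to a tensor, here is a small shape trace in plain Python. It assumes padded convs as in `double_conv` (so only the maxpool changes the size) and the channel widths 64 to 1024 used by the `Unet` class in this article; the function name `encoder_shapes` and the 256x256 input are illustrative assumptions.

```python
def encoder_shapes(height, width, base_ch=64, depth=4):
    """Return (channels, height, width) after inconv and after each down block."""
    shapes = [(base_ch, height, width)]   # inconv: no pooling, size unchanged
    ch = base_ch
    for _ in range(depth):
        ch *= 2          # each down block doubles the channels
        height //= 2     # MaxPool2d(2) halves the height...
        width //= 2      # ...and the width (rounding down)
        shapes.append((ch, height, width))
    return shapes

names = ['inc', 'down1', 'down2', 'down3', 'down4']
for name, shape in zip(names, encoder_shapes(256, 256)):
    print(name, shape)
# inc (64, 256, 256)
# down1 (128, 128, 128)
# down2 (256, 64, 64)
# down3 (512, 32, 32)
# down4 (1024, 16, 16)
```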
Up Path
The up path of Unet mainly consists of an upsample followed by a double_conv, but ConvTranspose2d can be used in place of the upsample. The up class below offers both options.
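In the up path we need to merge in the features from the down path. up.forward aligns the two feature maps so that they have the same height and width; note that this implementation pads the upsampled feature instead of cropping the down-path feature.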
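The size alignment matters whenever a feature map has an odd height or width, because maxpool rounds down and upsampling by 2 then comes up one pixel short. The sketch below isolates that padding step; `pad_to_match` is a hypothetical helper name and the 25x25 size is chosen only to trigger the mismatch.

```python
import torch
import torch.nn.functional as F

def pad_to_match(x1, x2):
    """Zero-pad x1 so its height and width equal those of x2."""
    diff_y = x2.size(2) - x1.size(2)
    diff_x = x2.size(3) - x1.size(3)
    # For a 4D tensor, F.pad takes (left, right, top, bottom).
    return F.pad(x1, (diff_x // 2, diff_x - diff_x // 2,
                      diff_y // 2, diff_y - diff_y // 2))

x2 = torch.randn(1, 64, 25, 25)   # down-path feature with an odd size
x1 = torch.randn(1, 64, 24, 24)   # 25 -> maxpool -> 12 -> upsample -> 24
x1_padded = pad_to_match(x1, x2)
merged = torch.cat([x2, x1_padded], dim=1)
print(x1_padded.shape)   # torch.Size([1, 64, 25, 25])
print(merged.shape)      # torch.Size([1, 128, 25, 25])
```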
class up(nn.Module):
    ''' up path
        conv_transpose => double_conv
    '''
    def __init__(self, in_ch, out_ch, Transpose=False):
        super(up, self).__init__()
        # would be a nice idea if the upsampling could be learned too,
        # but my machine does not have enough memory to handle all those weights
        if Transpose:
            self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
        else:
            # self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # bilinear upsample, then a 1x1 conv to halve the channels
            self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                                    nn.Conv2d(in_ch, in_ch//2, kernel_size=1, padding=0),
                                    nn.ReLU(inplace=True))
        self.conv = double_conv(in_ch, out_ch)
        self.up.apply(self.init_weights)

    def forward(self, x1, x2):
        '''
            x1: feature from the previous up layer, x2: skip feature from the down path
            conv output shape = (input_shape - Filter_shape + 2 * padding)/stride + 1
        '''
        x1 = self.up(x1)
        # pad x1 so that its height and width match x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # F.pad order for the last two dims: (left, right, top, bottom)
        x1 = nn.functional.pad(x1, (diffX // 2, diffX - diffX//2,
                                    diffY // 2, diffY - diffY//2))
        # concatenate along the channel dimension
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

    @staticmethod
    def init_weights(m):
        if type(m) == nn.Conv2d:
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
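The `Unet` class in this article also uses an `outconv` layer whose definition is not listed here. A plausible minimal version, assuming it follows the paper's final layer (a 1x1 convolution that maps the 64-channel feature to the desired number of output channels), could look like the sketch below. This is a guess at the missing piece, not the author's confirmed code.

```python
import torch
import torch.nn as nn

class outconv(nn.Module):
    ''' output conv layer (assumed definition)
        a 1x1 conv mapping in_ch feature channels to out_ch output channels
    '''
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

# Quick check: 64 feature channels in, one mask channel out, size unchanged.
head = outconv(64, 1)
logits = head(torch.randn(2, 64, 32, 32))
probs = torch.sigmoid(logits)   # per-pixel foreground probability
print(probs.shape)              # torch.Size([2, 1, 32, 32])
```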
Now that the building blocks are ready, let's implement Unet and get it running.
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch, gpu_ids=None):
        super(Unet, self).__init__()
        # running sums reported and reset by get_current_losses
        self.loss_stack = 0
        self.matrix_iou_stack = 0
        self.stack_count = 0
        self.display_names = ['loss_stack', 'matrix_iou_stack']
        self.gpu_ids = gpu_ids if gpu_ids is not None else []
        self.bce_loss = nn.BCELoss()
        # use the first listed GPU if there is one, otherwise the CPU
        if torch.cuda.is_available() and self.gpu_ids:
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))
        else:
            self.device = torch.device('cpu')
        self.inc = inconv(in_ch, 64)
        self.down1 = down(64, 128)
        # print(list(self.down1.parameters()))
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.drop3 = nn.Dropout2d(0.5)
        self.down4 = down(512, 1024)
        self.drop4 = nn.Dropout2d(0.5)
        self.up1 = up(1024, 512, False)
        self.up2 = up(512, 256, False)
        self.up3 = up(256, 128, False)
        self.up4 = up(128, 64, False)
        # single-channel mask output; out_ch is not used here
        self.outc = outconv(64, 1)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        # self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)

    def forward(self):
        # the input is taken from self.x, which is stored by set_input
        x1 = self.inc(self.x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x4 = self.drop3(x4)
        x5 = self.down4(x4)
        x5 = self.drop4(x5)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        self.pred_y = torch.sigmoid(x)

    def set_input(self, x, y):
        self.x = x.to(self.device)
        self.y = y.to(self.device)

    def optimize_params(self):
        # one training step: forward, loss, metric, backward, update
        self.forward()
        self._bce_iou_loss()
        _ = self.accu_iou()
        self.stack_count += 1
        self.zero_grad()
        self.loss.backward()
        self.optimizer.step()

    def accu_iou(self):
        # B is the mask pred, A is the melanoma
        y_pred = (self.pred_y > 0.5) * 1.0
        y_true = (self.y > 0.5) * 1.0
        pred_flat = y_pred.view(y_pred.numel())
        true_flat = y_true.view(y_true.numel())
        intersection = float(torch.sum(pred_flat * true_flat)) + 1e-7
        denominator = float(torch.sum(pred_flat + true_flat)) - intersection + 2e-7
        self.matrix_iou = intersection/denominator
        self.matrix_iou_stack += self.matrix_iou
        return self.matrix_iou

    def _bce_iou_loss(self):
        # loss = BCE + (1 - soft IoU), computed on the raw probabilities
        y_pred = self.pred_y
        y_true = self.y
        pred_flat = y_pred.view(y_pred.numel())
        true_flat = y_true.view(y_true.numel())
        intersection = torch.sum(pred_flat * true_flat) + 1e-7
        denominator = torch.sum(pred_flat + true_flat) - intersection + 1e-7
        iou = torch.div(intersection, denominator)
        bce_loss = self.bce_loss(pred_flat, true_flat)
        self.loss = bce_loss - iou + 1
        # accumulate a plain float so the autograd graph is not kept alive
        self.loss_stack += self.loss.item()

    def get_current_losses(self):
        # return the averages since the last call, then reset the counters
        errors_ret = {}
        for name in self.display_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, name)) / self.stack_count
        self.loss_stack = 0
        self.matrix_iou_stack = 0
        self.stack_count = 0
        return errors_ret

    def eval_iou(self):
        # evaluation step: same bookkeeping, no gradients and no update
        with torch.no_grad():
            self.forward()
            self._bce_iou_loss()
            _ = self.accu_iou()
            self.stack_count += 1
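The IoU metric computed in `accu_iou` is easier to check in isolation. The following standalone sketch mirrors that calculation on a tiny hand-made example; the function name `binary_iou` and the 2x2 tensors are assumptions for illustration only.

```python
import torch

def binary_iou(pred, target, threshold=0.5, eps=1e-7):
    """IoU between a thresholded prediction and a binary mask."""
    pred_flat = (pred > threshold).float().reshape(-1)
    true_flat = (target > threshold).float().reshape(-1)
    intersection = torch.sum(pred_flat * true_flat)
    union = torch.sum(pred_flat + true_flat) - intersection
    # eps keeps the ratio defined (and equal to 1) when both masks are empty
    return float((intersection + eps) / (union + eps))

pred = torch.tensor([[0.9, 0.8],
                     [0.2, 0.1]])    # thresholds to [[1, 1], [0, 0]]
target = torch.tensor([[1.0, 0.0],
                       [1.0, 0.0]])
iou = binary_iou(pred, target)
print(round(iou, 4))   # 0.3333: 1 overlapping pixel out of 3 in the union
```

The loss in `_bce_iou_loss` uses the same ratio but on the unthresholded probabilities, so that it stays differentiable.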
The rest of the code is just standard PyTorch boilerplate.
The code is adapted from GitHub.
Please credit the source when reposting.