論文代碼:https://github.com/HongguangZhang/DMPHN-cvpr19-master
論文地址:https://arxiv.org/pdf/1904.03468.pdf
論文解讀:https://blog.csdn.net/weixin_42784951/article/details/106108196
文章採用 1-2-4-8 的多層級結構,其中 1、2、4、8 分別表示由粗到細各層級網絡將輸入圖像劃分成的圖像塊數量。
網絡的每個層級都由一對編碼器/解碼器組成。通過將模糊輸入圖像 B1 分成多個不重疊的圖像塊來生成每個層級的輸入。較低層級(對應於更精細的網格)的編碼器和解碼器的輸出會被加到高一個層級上,以便頂層包含在較精細層級中推斷出的所有信息。請注意,每個層級的輸入和輸出圖像塊的數量是不同的,因爲該工作的主要思想是使較低層級的注意力集中在局部信息(更細的網格)上,從而爲較粗的網格提供殘差信息(通過級聯卷積獲得特徵)。
下面我們從編碼器開始對論文進行解讀:
如文章所述,編碼器由簡單的卷積層和激活函數(ReLU)組成,其網絡層定義如下:
# Encoder layer definitions. This is an excerpt; presumably it comes from the
# Encoder class __init__ (TODO confirm; the enclosing class is not shown).
# Channel progression across the three stages: 3 -> 32 -> 64 -> 128.
# All 3x3 convs use padding=1, so stride-1 convs keep the spatial size;
# only layer5 and layer9 (stride=2) downsample.
# The numbering skips layer4 and layer8; they are not defined here.
# The attribute names and the Sequential index layout determine the
# state_dict keys, so keep them stable for checkpoint compatibility.
# Conv1: lift the 3-channel input to 32 channels at full resolution.
self.layer1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
# conv-ReLU-conv block; used as a residual branch (block(x) + x) in the
# forward pass, hence equal input/output channels.
self.layer2 = nn.Sequential(
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=3, padding=1)
)
# Second 32-channel residual branch.
self.layer3 = nn.Sequential(
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=3, padding=1)
)
#Conv2: 32 -> 64 channels, stride-2 downsampling.
self.layer5 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
# Two 64-channel residual branches.
self.layer6 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1)
)
self.layer7 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1)
)
#Conv3: 64 -> 128 channels, stride-2 downsampling.
self.layer9 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
# Two 128-channel residual branches.
self.layer10 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, padding=1)
)
self.layer11 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, padding=1)
)
編碼器(Encoder)的前向傳播如下,其中 layer2/3、layer6/7、layer10/11 均帶有殘差連接(分支輸出與輸入相加):
# Encoder forward pass (excerpt; `self` and the input tensor `x` come from
# the enclosing forward method, which is not shown here).
# Each of the three stages is: one entry conv, then two residual blocks,
# each applied as  x <- block(x) + x.
for entry_conv, res_first, res_second in (
    (self.layer1, self.layer2, self.layer3),    # Conv1: 3 -> 32 channels
    (self.layer5, self.layer6, self.layer7),    # Conv2: 32 -> 64, stride 2
    (self.layer9, self.layer10, self.layer11),  # Conv3: 64 -> 128, stride 2
):
    x = entry_conv(x)
    x = res_first(x) + x
    x = res_second(x) + x
編碼器代碼如上;解碼器代碼與編碼器類似,只不過是將卷積換成去卷積(即轉置卷積)。
下面是網絡初始化:
# Build the three encoder/decoder levels, initialise their weights, move
# them to the GPU, create one Adam optimizer + StepLR scheduler per
# sub-network, and resume from saved checkpoints when they exist.
# `models`, `weight_init`, `GPU`, `LEARNING_RATE`, `METHOD` and `StepLR`
# are defined/imported elsewhere in the training script.
print("init data folders")
# Encoders for levels 1-3.
encoder_lv1 = models.Encoder()
encoder_lv2 = models.Encoder()
encoder_lv3 = models.Encoder()
# Decoders for levels 1-3.
decoder_lv1 = models.Decoder()
decoder_lv2 = models.Decoder()
decoder_lv3 = models.Decoder()
# Initialise encoder weights and move them to the GPU
# (Module.apply and Module.cuda both act in place).
encoder_lv1.apply(weight_init).cuda(GPU)
encoder_lv2.apply(weight_init).cuda(GPU)
encoder_lv3.apply(weight_init).cuda(GPU)
# Same for the decoders.
decoder_lv1.apply(weight_init).cuda(GPU)
decoder_lv2.apply(weight_init).cuda(GPU)
decoder_lv3.apply(weight_init).cuda(GPU)
# One optimizer per sub-network; each LR is multiplied by 0.1 every 1000
# scheduler steps.
encoder_lv1_optim = torch.optim.Adam(encoder_lv1.parameters(), lr=LEARNING_RATE)
encoder_lv1_scheduler = StepLR(encoder_lv1_optim, step_size=1000, gamma=0.1)
encoder_lv2_optim = torch.optim.Adam(encoder_lv2.parameters(), lr=LEARNING_RATE)
encoder_lv2_scheduler = StepLR(encoder_lv2_optim, step_size=1000, gamma=0.1)
encoder_lv3_optim = torch.optim.Adam(encoder_lv3.parameters(), lr=LEARNING_RATE)
encoder_lv3_scheduler = StepLR(encoder_lv3_optim, step_size=1000, gamma=0.1)
decoder_lv1_optim = torch.optim.Adam(decoder_lv1.parameters(), lr=LEARNING_RATE)
decoder_lv1_scheduler = StepLR(decoder_lv1_optim, step_size=1000, gamma=0.1)
decoder_lv2_optim = torch.optim.Adam(decoder_lv2.parameters(), lr=LEARNING_RATE)
decoder_lv2_scheduler = StepLR(decoder_lv2_optim, step_size=1000, gamma=0.1)
decoder_lv3_optim = torch.optim.Adam(decoder_lv3.parameters(), lr=LEARNING_RATE)
decoder_lv3_scheduler = StepLR(decoder_lv3_optim, step_size=1000, gamma=0.1)
# Resume from previously trained weights, one .pkl file per sub-network.
# load_state_dict copies into the existing parameters, so the optimizers
# created above keep pointing at the right tensors.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint_dir = './checkpoints/' + METHOD
for net_name, net in (
    ("encoder_lv1", encoder_lv1),
    ("encoder_lv2", encoder_lv2),
    ("encoder_lv3", encoder_lv3),
    ("decoder_lv1", decoder_lv1),
    ("decoder_lv2", decoder_lv2),
    ("decoder_lv3", decoder_lv3),
):
    checkpoint_path = checkpoint_dir + "/" + net_name + ".pkl"
    if os.path.exists(checkpoint_path):
        net.load_state_dict(torch.load(checkpoint_path))
        # Derived from net_name, so decoder_lv1 no longer reports
        # "load encoder_lv1 success".
        print("load " + net_name + " success")
# Create the checkpoint directory if it is missing. os.makedirs also creates
# ./checkpoints itself when needed, and avoids building a shell command from
# METHOD (the old `mkdir` call failed when the parent was missing).
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
接下來是迭代訓練的過程。如上圖所示,這裏的代碼只使用 1-2-4 三個尺度(而非前文提到的 1-2-4-8):在三個尺度中分別將輸入的模糊圖像分成 1、2、4 個部分,再分別送入三個尺度的編碼器-解碼器網絡,代碼如下:
# One pass over the training set for the 1-2-4 hierarchy (three levels,
# run fine-to-coarse: level 3 -> level 2 -> level 1).
# The MSE criterion holds no per-batch state, so build it once instead of
# re-creating it (and moving it to the GPU) on every iteration.
mse = nn.MSELoss().cuda(GPU)
for iteration, images in enumerate(train_dataloader):
    # Sharp target and full-resolution blurry input (level 1), both shifted
    # by -0.5 (presumably to centre [0, 1] pixel values on zero; confirm
    # against the dataset transform).
    gt = Variable(images['sharp_image'] - 0.5).cuda(GPU)
    H = gt.size(2)
    W = gt.size(3)
    images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
    # Level 2: split into top/bottom halves along the height (dim 2).
    images_lv2_1 = images_lv1[:, :, 0:H // 2, :]
    images_lv2_2 = images_lv1[:, :, H // 2:H, :]
    # Level 3: split each level-2 half into left/right along the width
    # (dim 3), giving four patches.
    images_lv3_1 = images_lv2_1[:, :, :, 0:W // 2]
    images_lv3_2 = images_lv2_1[:, :, :, W // 2:W]
    images_lv3_3 = images_lv2_2[:, :, :, 0:W // 2]
    images_lv3_4 = images_lv2_2[:, :, :, W // 2:W]
    # Finest level first: the four level-3 patches go through the shared
    # level-3 encoder.
    feature_lv3_1 = encoder_lv3(images_lv3_1)
    feature_lv3_2 = encoder_lv3(images_lv3_2)
    feature_lv3_3 = encoder_lv3(images_lv3_3)
    feature_lv3_4 = encoder_lv3(images_lv3_4)
    # Re-assemble the patch features spatially: left|right (dim 3) into a
    # top and a bottom row, then top/bottom (dim 2) into the full grid.
    feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3)
    feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3)
    feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
    # Decode each row into a residual image for the matching level-2 half.
    residual_lv3_top = decoder_lv3(feature_lv3_top)
    residual_lv3_bot = decoder_lv3(feature_lv3_bot)
    # Level 2: each half plus its level-3 residual is encoded; the stacked
    # features are added to the level-3 features before decoding.
    # NOTE(review): these additions assume each decoder output has exactly
    # the spatial size of its input patch; with two stride-2 encoder stages
    # that presumably requires H and W divisible by 8 - confirm crop size.
    feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
    feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
    feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + feature_lv3
    residual_lv2 = decoder_lv2(feature_lv2)
    # Level 1: full image plus the level-2 residual; adding the level-2
    # features and decoding gives the final deblurred image.
    feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
    deblur_image = decoder_lv1(feature_lv1)
    # Only the final (level-1) output is supervised.
    loss_lv1 = mse(deblur_image, gt)
    loss = loss_lv1
    # Clear the gradients of all six sub-networks, back-propagate, then
    # update each one with its own optimizer.
    encoder_lv1.zero_grad()
    encoder_lv2.zero_grad()
    encoder_lv3.zero_grad()
    decoder_lv1.zero_grad()
    decoder_lv2.zero_grad()
    decoder_lv3.zero_grad()
    loss.backward()
    encoder_lv1_optim.step()
    encoder_lv2_optim.step()
    encoder_lv3_optim.step()
    decoder_lv1_optim.step()
    decoder_lv2_optim.step()
    decoder_lv3_optim.step()