基於DANN的圖像分類任務遷移學習

注:本博客的數據和任務來自NTU-ML2020作業,Kaggle網址爲Kaggle.

數據預處理

我們要進行遷移學習的對象是10000張32x32x3的有標籤正常照片,共有10類,和另外100000張人類畫的手繪圖,28x28x1黑白照片,類別也是10類但無標籤。我們希望做到,讓模型從有標籤的原始分佈數據中學到的知識能應用於無標籤的,相似但與原始分佈不相同的目標分佈中,並提高黑白手繪圖的正確率。
爲此,訓練前還要對數據做預處理。首先讓原始分佈的圖像和目標分佈的圖像儘可能相似,我們要做有色圖轉灰度圖,然後做邊緣檢測。爲了模型的輸入維度相同,要把28x28轉爲32x32.此外還可以增加一些平移旋轉來讓學習更魯棒。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt

# 在transform中使用轉灰度-canny邊緣提取-水平移動-小幅度旋轉-轉張量操作

source_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
])
target_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
])

# 讀取數據集,分爲source和target兩部分

source_dataset = ImageFolder('E:/real_or_drawing/train_data', transform=source_transform)
target_dataset = ImageFolder('E:/real_or_drawing/test_data', transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

DANN

Domain-Adversarial Training of NNs,值域對抗學習。這種算法是我們這裏將要用的遷移學習方法,它被提出的起因是讓CNN能夠同時用於不同分佈的數據,如果模型直接接收原值域的數據分佈進行訓練,即使原分佈和目標分佈有類似的地方,在接收目標值域的數據時,也會出現相當異常的特徵提取和分類結果。我們可以理解爲是模型在源數據分佈上出現了過擬合(並不是對數據的過擬合),在接收一些沒有見到過的數據時自然會表現不佳。
在這裏插入圖片描述
解決這個問題最好的辦法就是讓模型在訓練時也接收目標數據分佈的數據。但是目標數據分佈是無標籤的,我們要用什麼標準來訓練模型呢?回憶CNN的架構,CNN使用卷積-池化的特徵提取層來提取圖片特徵,後接全連接層進行預測。我們只需要讓特徵提取層既能提取原數據分佈的特徵,又能提取目標數據分佈的特徵,這樣全連接層就能對兩種值域但具有相同特徵的數據進行同樣的分類,從而目標數據分佈的輸入也很有可能被正確分類。
在這裏插入圖片描述
那麼問題就變成了如何訓練輸入兩個不同分佈的數據,輸出卻是同種分佈的特徵提取層。回憶GAN的架構,我們讓分佈朝着源數據分佈發展的方法是建立判別器,讓判別器能分辨兩種數據,而讓生成器改變參數騙過判別器。這裏也可以用同樣的思想,我們建立能分辨原始分佈和目標分佈的二分類判別器,把特徵提取層和二分類判別層接在一起。首先訓練判別器,讓判別器能分辨兩類數據分佈。然後訓練特徵提取層,逆梯度更新讓特徵提取層生成能騙過判別器的數據(目標輸出0.5).如此訓練多次直到特徵提取層能把兩種值域的輸入變成同種分佈的輸出。
在這裏插入圖片描述
但是隻是用GAN方法train特徵提取層並不明智,因爲我們的目標輸出只有0-1的二分類,訓練很有可能只是讓特徵提取層提取到一些沒有用的特徵。因此我們要一邊訓練正常的標籤預測任務,一邊訓練判別器的判別任務和混淆兩類輸入的任務。這可能需要自己定義特殊的loss function

最後,我們就獲得了能同時提取兩個值域的特徵的特徵提取層,它後面的多分類層就可以對目標分佈的數據做出還算稱心如意的預測。

模型、訓練、測試代碼

這裏使用類VGG(用多個3x3的卷積核代替大型卷積核以節約參數)的搭建方式,寫一個高度卷積的特徵提取層

class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
    def forward(self, x):
        x = self.conv(x).squeeze()
        return x

#值域分類器,即GAN中的discriminator
class DomainClassifier(nn.Module):
    def __init__(self):
        super(DomainClassifier, self).__init__()

        self.layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 1),
        )

    def forward(self, h):
        y = self.layer(h)
        return y

#標籤預測器,對特徵作進一步分類
class LabelPredictor(nn.Module):

    def __init__(self):
        super(LabelPredictor, self).__init__()

        self.layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.ReLU(),

            nn.Linear(512, 10),
        )

    def forward(self, h):
        c = self.layer(h)
        return c
feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()

# 多分類使用交叉熵損失進行訓練
class_criterion = nn.CrossEntropyLoss()
# domain_classifier的輸出是1維,要先sigmoid轉概率再計算交叉熵,使用BCEWithlogits
domain_criterion = nn.BCEWithLogitsLoss()

# 使用adam訓練
optimizer_F = optim.Adam(feature_extractor.parameters())
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())

我們訓練200個epoch,讓數據儘量收斂

def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: source data的dataloader
        target_dataloader: target data的dataloader
        lamb: 對抗的lamb係數
    '''

    # D loss: Domain Classifier的loss
    # F loss: Feature Extrator & Label Predictor的loss
    # total_hit: 計算目前對了幾筆 total_num: 目前經過了幾筆
    running_D_loss, running_F_loss = 0.0, 0.0
    total_hit, total_num = 0.0, 0.0

    for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):

        source_data = source_data.cuda()
        source_label = source_label.cuda()
        target_data = target_data.cuda()
        
        # 把source data和target data混在一起,否則batch_norm會出錯
        mixed_data = torch.cat([source_data, target_data], dim=0)
        # 設置判別器的目標標籤
        domain_label = torch.zeros([source_data.shape[0] + target_data.shape[0], 1]).cuda()
        domain_label[:source_data.shape[0]] = 1

        # Step 1 : 訓練Domain Classifier
        feature = feature_extractor(mixed_data)
        # 這裏detach feature,因爲不需要更新extractor的參數
        domain_logits = domain_classifier(feature.detach())
        loss = domain_criterion(domain_logits, domain_label)
        running_D_loss+= loss.item()
        loss.backward()
        optimizer_D.step()

        # Step 2 : 訓練Feature Extractor和Domain Classifier
        class_logits = label_predictor(feature[:source_data.shape[0]])
        domain_logits = domain_classifier(feature)
        # 這裏使用的loss是原值域數據的任務分類交叉熵損失減去,原值域數據和目標值域數據的判別損失
        # 因爲我們想讓extractor騙過判別器,判別損失加負號,而且爲了調控訓練使用lambda作爲係數
        loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)
        running_F_loss+= loss.item()
        loss.backward()
        optimizer_F.step()
        optimizer_C.step()

        optimizer_D.zero_grad()
        optimizer_F.zero_grad()
        optimizer_C.zero_grad()

        total_hit += torch.sum(torch.argmax(class_logits, dim=1) == source_label).item()
        total_num += source_data.shape[0]
        print(i, end='\r')

    return running_D_loss / (i+1), running_F_loss / (i+1), total_hit / total_num


# 訓練50 epochs
for epoch in range(50):
    train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, lamb=0.1)

    torch.save(feature_extractor.state_dict(), f'extractor_model.bin')
    torch.save(domain_classifier.state_dict(), f'domain_model.bin')
    torch.save(label_predictor.state_dict(), f'predictor_model.bin')

    print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

訓練好之後可以看見原值域上的訓練集正確率有98以上,想看手繪圖片的正確率可以在Kaggle上提交一下。我們這裏隨便打印一些手繪圖片和模型預測的標籤。

feature_extractor.load_state_dict(torch.load('extractor_model.bin'))
domain_classifier.load_state_dict(torch.load('domain_model.bin'))
label_predictor.load_state_dict(torch.load('predictor_model.bin'))

for i, (data, _) in enumerate(test_dataloader):
    break
    
class_logits = label_predictor(feature_extractor(data.cuda()))

#我們看50張手繪圖的預測

def no_axis_show(img, title='', cmap=None):
    fig = plt.imshow(img, interpolation='nearest', cmap=cmap)
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(title)

titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
data = data.cuda()
for i in range(50):
    plt.subplot(5, 10, i+1)
    label = torch.argmax(class_logits[i]).cpu().detach().numpy()
    img = data[i].cpu().detach().numpy().reshape(32,32)
    fig = no_axis_show(img, title=titles[label])

在這裏插入圖片描述
正確率不能說有多高,但是模型似乎學會了分辨一些特徵比較明顯的圖片。

值域

把特徵提取層得到的特徵用PCA降維可以在2D平面上看到值域的分佈。
在這裏插入圖片描述
在不使用DANN時,原值域和目標值域是分開的,這樣的特徵投入全連接層必然不work。但是當我們強制讓模型把兩種數據的特徵混在一起,就變成右圖,這時目標值域的特徵有機會被正確分類。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章