24 Model Fine-tuning (finetune)

1. Transfer Learning & Model Finetune

1.1 Transfer Learning

Transfer Learning: a branch of machine learning that studies how knowledge from a source domain can be applied to a target domain.

Traditional machine learning:
Each task is trained separately, yielding a different learning system (model) per task; with three different tasks, you end up with three different models.

Transfer learning:
First learn on the source task to obtain knowledge, then use that knowledge when training the model on the target task. In other words, the model draws not only on the target tasks but also on the source tasks.

1.2 Model Finetune

1.2.1 The Concept of Model Finetune

Model Finetune: transfer learning applied to a model.
Model fine-tuning:
Fine-tuning a model is a transfer-learning process. The weights a model learns during training are the "knowledge" in transfer learning, and this knowledge can be transferred: carrying these weights over to a new task completes the transfer.

Why fine-tune:
When the new task has too little data to train a large model from scratch, Model Finetune helps obtain a good model and makes training faster.

Transferring a convolutional neural network:
A convolutional neural network can be split into two parts: a features extractor + a classifier.

  • features extractor: the generic part of the model, usually kept as-is
  • classifier: the output layer, finetuned according to the requirements of the new task

1.2.2 Model Finetune Steps

Model Finetune:
First prepare the model: load the pretrained parameters and modify the model to fit the new task; then run the actual training, taking care how the pretrained parameters are handled. The concrete steps and methods are as follows.

Steps for model fine-tuning:

  1. Obtain the pretrained model parameters
  2. Load them into the model (load_state_dict)
  3. Modify the output layer
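The three steps above can be sketched as follows. To keep the snippet self-contained and avoid downloading weights, a hypothetical `TinyNet` stands in for a pretrained network; in the real example below, resnet18 and its downloaded `.pth` file play these roles:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # hypothetical stand-in for a pretrained network: feature extractor + classifier
    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.fc = nn.Linear(4, num_classes)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))

# step 1: obtain pretrained parameters (here saved from a "source-task" model)
source_model = TinyNet()
torch.save(source_model.state_dict(), "tiny_pretrained.pth")

# step 2: load them into a fresh model with load_state_dict
model = TinyNet()
state_dict = torch.load("tiny_pretrained.pth")
model.load_state_dict(state_dict)

# step 3: replace the output layer for the new 2-class task
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
print(model.fc.out_features)  # 2
```

The feature-extractor weights keep their pretrained values; only the new `fc` layer starts from random initialization.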

Training methods for model fine-tuning:

  • Freeze the pretrained parameters, via either of two methods:
    • requires_grad = False
    • lr = 0
  • Use a smaller learning rate for the Features Extractor part (params_group)

Note:
An optimizer can manage multiple parameter groups, each with its own hyper-parameters; this is how the Features Extractor part is given a smaller learning rate.
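Both strategies can be sketched on a toy model; `TwoPartNet` below is a hypothetical stand-in where `features` plays the feature extractor and `fc` the classifier (the real script below applies the same idea to resnet18):

```python
import torch.nn as nn
import torch.optim as optim

class TwoPartNet(nn.Module):
    # hypothetical stand-in: `features` = feature extractor, `fc` = classifier
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

net = TwoPartNet()

# method 1: freeze the feature extractor (no gradients are computed for it)
for param in net.features.parameters():
    param.requires_grad = False

# method 2: smaller learning rate for the feature extractor via parameter groups
LR = 0.001
fc_params_id = list(map(id, net.fc.parameters()))
base_params = [p for p in net.parameters() if id(p) not in fc_params_id]
optimizer = optim.SGD([
    {'params': base_params, 'lr': LR * 0.1},    # feature extractor: LR / 10 (0 would freeze it)
    {'params': net.fc.parameters(), 'lr': LR},  # new output layer: full LR
], momentum=0.9)
```

In practice you would pick one of the two methods; setting the first group's lr to 0 makes method 2 behave like method 1.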

2. Finetune in PyTorch

2.1 A Model Finetune Example

Data: https://download.pytorch.org/tutorial/hymenoptera_data.zip
Model: https://download.pytorch.org/models/resnet18-5c106cde.pth

2.1.1 Directory Structure

Place the downloaded model weights and data under the data directory, matching the paths used in the code below.

2.1.2 Code Walkthrough

my_dataset.py

# -*- coding: utf-8 -*-
import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)


class AntsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {"ants": 0, "bees": 1}
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img,label

    def __len__(self):
        return len(self.data_info)

    def get_img_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over class sub-directories
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # iterate over images
                for img_name in img_names:
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = self.label_name[sub_dir]
                    data_info.append((path_img, int(label)))

        if len(data_info) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(data_dir))
        return data_info

common_tools.py

# -*- coding: utf-8 -*-

import torch
import random
import numpy as np
from PIL import Image
import torchvision.transforms as transforms


def transform_invert(img_, transform_train):
    """
    將data 進行反transfrom操作
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])

    img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    if 'ToTensor' in str(transform_train):
        img_ = img_.detach().numpy() * 255

    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )

    return img_


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
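The key step in transform_invert is undoing Normalize. A quick standalone check of that denormalization, using the same ImageNet mean/std as the training script:

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

img = torch.rand(3, 4, 4)  # a fake C*H*W image in [0, 1]
# what transforms.Normalize does: per-channel (x - mean) / std
normed = (img - mean[:, None, None]) / std[:, None, None]
# what transform_invert undoes: per-channel x * std + mean
restored = normed * std[:, None, None] + mean[:, None, None]

print(torch.allclose(restored, img, atol=1e-6))  # True
```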

finetune_resnet18.py

# -*- coding: utf-8 -*-

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.my_dataset import AntsDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device used for training
print("use device :{}".format(device))

set_seed(1)  # set random seed
label_name = {"ants": 0, "bees": 1}

# hyper-parameters
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7


# ============================ step 1/5 data ============================
data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data")
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build Dataset instances
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

# 1/3 build the model
resnet18_ft = models.resnet18()

# 2/3 load pretrained parameters
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# method 1: freeze the conv layers
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))


# 3/3 replace the fc layer
num_ftrs = resnet18_ft.fc.in_features            # number of input features of the original resnet18 fc layer
resnet18_ft.fc = nn.Linear(num_ftrs, classes)


resnet18_ft.to(device)        # move the model to the chosen device
# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose loss function

# ============================ step 4/5 optimizer ============================
# method 2: smaller lr for the conv layers
flag = 0
# flag = 1
if flag:
    # split the model parameters into two groups: resnet18_ft.fc.parameters() and base_params
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))     # ids (memory addresses) of the fc parameters
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())

    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR*0.1},   # set to 0 to freeze these layers
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)

else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)               # choose optimizer

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)     # learning-rate decay schedule


# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)   # training data must also go to the chosen device
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().cpu().sum().numpy()

        # log training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # if flag_m1:
            # print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                loss_val += loss.item()

            loss_val_mean = loss_val/len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # valid records one loss per epoch; convert the record points to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
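The StepLR schedule used above multiplies the learning rate by gamma=0.1 every lr_decay_step=7 epochs. A standalone illustration with the same settings (the `nn.Linear` model here is just a placeholder to give the optimizer parameters):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # placeholder model
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(15):
    optimizer.step()  # a real training loop would step on each batch
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()

print(lrs[0], lrs[7], lrs[14])  # lr drops by 10x at epoch 7 and again at epoch 14
```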

Results without finetune:
Without finetuning, the final training loss is 0.5729 and the accuracy is 67.50%.

Results after loading pretrained parameters (transfer learning):
With transfer learning, i.e. training after loading the already-learned parameters, the accuracy starts from over sixty percent and quickly reaches a high level. Using finetune therefore makes the model train faster.

Method 1: freeze the conv layers, results:
With the conv layers frozen, their parameters stay unchanged throughout the iterations.

Method 2: small learning rate for conv, results:
Training results with a smaller learning rate for the conv layers; here it is set to 0.0001.

Method 2: small learning rate for conv (learning rate set to 0), results:
With the learning rate set to 0, the conv-layer parameters stay unchanged during training; in this case the effect is the same as method 1.
