1. Transfer Learning & Model Finetune
1.1 Transfer Learning
Transfer Learning: a branch of machine learning that studies how knowledge from a source domain can be applied to a target domain.
Traditional machine learning:
A separate learning system (i.e. model) is trained for each task; as in the figure above, three different tasks yield three different models.
Transfer learning:
First learn on the source task to obtain knowledge, then use the knowledge learned on the source task when training a model on the target task; in other words, the model draws not only on the target task but also on the source tasks.
1.2 Model Finetune
1.2.1 The concept of Model Finetune
Model Finetune: transfer learning applied to a model.
Model finetuning:
Finetuning is itself a transfer-learning process: the weights a model learns during training are the "knowledge" in transfer learning, and this knowledge can be transferred. Carrying these weights over to a new task completes the transfer.
Why finetune:
The new task often has too little data to train a large model from scratch, so Model Finetune is used to help train a good model and make training converge faster.
Transferring a convolutional neural network:
Split the CNN into two parts: features extractor + classifier
- features extractor: the generic part of the model, usually kept as-is
- classifier: the output layer, finetuned according to the requirements of the new task
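The two-part split above can be sketched with a toy network (not ResNet; the layer sizes here are arbitrary illustrations, and any real pretrained backbone would take the place of the small `Sequential` stack):

```python
import torch
import torch.nn as nn

# toy CNN split into the two parts described above (sizes are illustrative)
features_extractor = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),                 # generic, reusable part
)
classifier = nn.Linear(8, 2)                 # task-specific output layer to finetune

x = torch.randn(4, 3, 32, 32)                # a fake batch of 4 RGB images
feat = features_extractor(x)                 # shape (4, 8, 1, 1)
logits = classifier(torch.flatten(feat, 1))  # shape (4, 2)
```

Only `classifier` depends on the number of target classes, which is why it is the part that gets replaced when the task changes.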
1.2.2 Model Finetune steps
Model Finetune:
First adapt the model: load the pretrained parameters and modify the model according to the task requirements, then run regular training. During training, take care to preserve the pretrained parameters (by freezing them or using a small learning rate). The concrete steps and methods are as follows.
Finetuning steps:
- Obtain the pretrained model parameters
- Load them into the model (load_state_dict)
- Replace the output layer
Finetuning training methods:
- Fix the pretrained parameters, in one of two ways:
- requires_grad = False
- lr = 0
- Set a smaller learning rate for the Features Extractor part (params_group)
Note:
The optimizer can manage multiple parameter groups, so different hyperparameters can be set for different groups; this is how a smaller learning rate is applied to the Features Extractor part.
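Parameter groups can be sketched with a stand-in two-layer model (layer `[0]` plays the feature extractor and layer `[1]` the classifier; the learning rates are illustrative):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))

fc_params = list(model[1].parameters())
fc_ids = set(map(id, fc_params))                  # identify classifier params by id
base_params = [p for p in model.parameters() if id(p) not in fc_ids]

optimizer = optim.SGD([
    {"params": base_params, "lr": 0.0001},        # small lr for the feature extractor
    {"params": fc_params, "lr": 0.001},           # normal lr for the classifier
], momentum=0.9)
```

Hyperparameters not given in a group (here `momentum`) fall back to the value passed to the optimizer itself, so each group only needs to override what differs.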
2. Finetune in PyTorch
2.1 A Model Finetune example
Data: https://download.pytorch.org/tutorial/hymenoptera_data.zip
Model: https://download.pytorch.org/models/resnet18-5c106cde.pth
2.1.1 Directory structure
The model and data are stored as shown in the figure above.
2.1.2 Code walkthrough
my_dataset.py
```python
# -*- coding: utf-8 -*-
import os
from PIL import Image
from torch.utils.data import Dataset


class AntsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {"ants": 0, "bees": 1}
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_info)

    def get_img_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over the class sub-directories
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # iterate over the images of this class
                for img_name in img_names:
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = self.label_name[sub_dir]
                    data_info.append((path_img, int(label)))
        if len(data_info) == 0:
            raise Exception("\ndata_dir:{} is an empty dir! Please check your path to images!".format(data_dir))
        return data_info
```
common_tools.py
```python
# -*- coding: utf-8 -*-
import torch
import random
import numpy as np
from PIL import Image
import torchvision.transforms as transforms


def transform_invert(img_, transform_train):
    """
    Invert the transforms applied to the data, recovering a displayable image
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])

    img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    if 'ToTensor' in str(transform_train):
        img_ = img_.detach().numpy() * 255

    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]))

    return img_


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
```
finetune_resnet18.py
```python
# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.my_dataset import AntsDataset
from tools.common_tools import set_seed
import torchvision.models as models

BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device used for training
print("use device :{}".format(device))

set_seed(1)  # set the random seed
label_name = {"ants": 0, "bees": 1}

# hyperparameters
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7

# ============================ step 1/5 Data ============================
data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data")
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build Dataset instances
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 Model ============================
# 1/3 build the model
resnet18_ft = models.resnet18()

# 2/3 load the pretrained parameters
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# Method 1: freeze the conv layers
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3 replace the fc layer
num_ftrs = resnet18_ft.fc.in_features  # number of input features of the original resnet18 fc layer
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)  # move the model to the chosen device

# ============================ step 3/5 Loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5 Optimizer ============================
# Method 2: small learning rate for the conv layers
flag = 0
# flag = 1
if flag:
    # split the model parameters into two groups: resnet18_ft.fc.parameters() and base_params
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))  # ids (memory addresses) of the fc parameters
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR*0.1},  # group 0: feature extractor, smaller lr
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)  # choose the optimizer

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # lr decay schedule

# ============================ step 5/5 Training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)  # move the data to the device as well
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # count classification results
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().cpu().sum().numpy()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # if flag_m1:
            #     print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                loss_val += loss.item()

            loss_val_mean = loss_val/len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid_curve records a per-epoch loss, so convert its record points to iteration counts
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
```
Results without finetuning:
Without finetuning, the final training loss is 0.5729 and the accuracy is 67.50%.
Results with transfer learning (pretrained parameters loaded):
As can be seen, after loading the learned parameters, accuracy starts at over sixty percent and quickly reaches a high value; finetuning therefore lets the model train much faster.
Method 1 (freeze the conv layers), results:
As the figure above shows, with the conv layers frozen, their parameters stay unchanged across iterations.
Method 2 (small learning rate for the conv layers), results:
Training results with a smaller learning rate on the conv layers; here it is set to 0.0001.
Method 2 with the learning rate set to 0, results:
With the learning rate set to 0, the conv-layer parameters stay unchanged during training, so this method becomes equivalent to Method 1.
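The equivalence claimed above is easy to check in isolation: one SGD step with lr=0 leaves the weights exactly where they were, just like freezing them with requires_grad=False (toy layer below, not the ResNet):

```python
import torch
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(4, 2)
w_before = layer.weight.detach().clone()

opt = optim.SGD(layer.parameters(), lr=0.0)      # Method 2 with lr = 0
loss = layer(torch.randn(3, 4)).sum()
loss.backward()
opt.step()                                       # update with zero learning rate

unchanged = torch.equal(layer.weight, w_before)  # weights are bit-identical
```

One practical difference remains: with requires_grad=False no gradients are computed for the frozen layers at all, so Method 1 also saves compute and memory during backward.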