遷移學習 Transfer Learning—通俗易懂地介紹(常見網絡模型pytorch實現)

前言

你會發現聰明人都喜歡”偷懶”, 因爲這樣的偷懶能幫我們節省大量的時間, 提高效率. 還有一種偷懶是 “站在巨人的肩膀上”. 不僅能看得更遠, 還能看到更多. 這也用來表達我們要善於學習先輩的經驗, 一個人的成功往往還取決於先輩們累積的知識. 這句話, 放在機器學習中, 這就是今天要說的遷移學習了, transfer learning.

什麼是遷移學習?

遷移學習通俗來講,就是運用已有的知識來學習新的知識,核心是找到已有知識和新知識之間的相似性,用成語來說就是舉一反三。由於直接對目標域從頭開始學習成本太高,我們故而轉向運用已有的相關知識來輔助儘快地學習新知識。比如,已經會下中國象棋,就可以類比着來學習國際象棋;已經會編寫Java程序,就可以類比着來學習C#;已經學會英語,就可以類比着來學習法語;等等。世間萬事萬物皆有共性,如何合理地找尋它們之間的相似性,進而利用這個橋樑來幫助學習新知識,是遷移學習的核心問題。

爲什麼需要遷移學習?

在這裏插入圖片描述
現在的機器人視覺已經非常先進了, 有些甚至超過了人類. 99.99%的識別準確率都不在話下. 這樣的成功, 依賴於強大的機器學習技術, 其中, 神經網絡成爲了領軍人物. 而 CNN 等, 像人一樣擁有千千萬萬個神經聯結的結構, 爲這種成功貢獻了巨大力量. 但是爲了更厲害的 CNN, 我們的神經網絡設計, 也從簡單的幾層網絡, 變得越來越多, 越來越多, 越來越多… 爲什麼會越來越多?

因爲計算機硬件, 比如 GPU 變得越來越強大, 能夠更快速地處理龐大的信息. 在同樣的時間內, 機器能學到更多東西. 可是, 不是所有人都擁有這麼龐大的計算能力. 而且有時候面對類似的任務時, 我們希望能夠借鑑已有的資源.

如何做遷移學習?

在這裏插入圖片描述這就好比, Google 和百度的關係, facebook 和人人的關係, KFC 和 麥當勞的關係, 同一類型的事業, 不用自己完全從頭做, 借鑑對方的經驗, 往往能節省很多時間. 有這樣的思路, 我們也能偷偷懶, 不用花時間重新訓練一個無比龐大的神經網絡, 借鑑借鑑一個已經訓練好的神經網絡就行.

在這裏插入圖片描述比如這樣的一個神經網絡, 我花了兩天訓練完之後, 它已經能正確區分圖片中具體描述的是男人, 女人還是眼鏡. 說明這個神經網絡已經具備對圖片信息一定的理解能力. 這些理解能力就以參數的形式存放在每一個神經節點中. 不巧, 領導下達了一個緊急任務,
在這裏插入圖片描述要求今天之內訓練出來一個預測圖片裏實物價值的模型. 我想這可完蛋了, 上一個圖片模型都要花兩天, 如果要再搭個模型重新訓練, 今天肯定出不來呀.

這時, 遷移學習來拯救我了. 因爲這個訓練好的模型中已經有了一些對圖片的理解能力, 而模型最後輸出層的作用是分類之前的圖片, 對於現在計算價值的任務是用不到的, #所以我將最後一層替換掉, 變爲服務於現在這個任務的輸出層. #接着只訓練新加的輸出層, 讓理解力保持始終不變. 前面的神經層龐大的參數不用再訓練, 節省了我很多時間, 我也在一天時間內, 將這個任務順利完成.

在這裏插入圖片描述
但並不是所有時候我們都需要遷移學習. 比如神經網絡很簡單, 相比起計算機視覺中龐大的 CNN 或者語音識別的 RNN, 訓練小的神經網絡並不需要特別多的時間, 我們完全可以直接重頭開始訓練. 從頭開始訓練也是有好處的.

在這裏插入圖片描述如果固定住之前的理解力, 或者使用更小的學習率來更新借鑑來的模型, 就變得有點像認識一個人時的第一印象, 如果遷移前的數據和遷移後的數據差距很大, 或者說我對於這個人的第一印象和後續印象差距很大, 我還不如不要管我的第一印象, 同理, 這時, 遷移來的模型並不會起多大作用, 還可能干擾我後續的決策.

遷移學習的限制

比如說,我們不能隨意移除預訓練網絡中的卷積層。但由於參數共享的關係,我們可以很輕鬆地在不同空間尺寸的圖像上運行一個預訓練網絡。這在卷積層和池化層和情況下是顯而易見的,因爲它們的前向函數(forward function)獨立於輸入內容的空間尺寸。在全連接層(FC)的情形中,這仍然成立,因爲全連接層可被轉化成一個卷積層。所以當我們導入一個預訓練的模型時,網絡結構需要與預訓練的網絡結構相同,然後再針對特定的場景和任務進行訓練。

常見的遷移學習方式:

  1. 載權重後訓練所有參數
  2. 載入權重後只訓練最後幾層參數
  3. 載入權重後在原網絡基礎上再添加一層全鏈接層,僅訓練最後一個全鏈接層

衍生

在這裏插入圖片描述瞭解了一般的遷移學習玩法後, 我們看看前輩們還有哪些新玩法. 多任務學習, 或者強化學習中的 learning to learn, 遷移機器人對運作形式的理解, 解決不同的任務. 炒個蔬菜, 紅燒肉, 番茄蛋花湯雖然菜色不同, 但是做菜的原則是類似的.
在這裏插入圖片描述
又或者 google 的翻譯模型, 在某些語言上訓練, 產生出對語言的理解模型, 將這個理解模型當做遷移模型在另外的語言上訓練. 其實說白了, 那個遷移的模型就能看成機器自己發明的一種只有它自己才能看懂的語言. 然後用自己的這個語言模型當成翻譯中轉站, 將某種語言轉成自己的語言, 然後再翻譯成另外的語言. 遷移學習的腦洞還有很多, 相信這種站在巨人肩膀上繼續學習的方法, 還會帶來更多有趣的應用.

使用圖像數據進行遷移學習

  • 牛津 VGG 模型(http://www.robots.ox.ac.uk/~vgg/research/very_deep/)
  • 谷歌 Inception模型(https://github.com/tensorflow/models/tree/master/inception)
  • 微軟 ResNet 模型(https://github.com/KaimingHe/deep-residual-networks)

可以在 Caffe Model Zoo(https://github.com/BVLC/caffe/wiki/Model-Zoo)中找到更多的例子,那裏分享了很多預訓練的模型。

實例:

注:如何獲取官方的.pth文件,以resnet爲例子

import torchvision.models.resnet

在腳本中輸入以上代碼,將鼠標對住resnet並按ctrl鍵,發現改變顏色,點擊進入resnet.py腳本,在最開始有url,如下圖所示
在這裏插入圖片描述選擇你要下載的模型,copy到瀏覽器即可,若是覺得慢可以用迅雷等等。

ResNet
詳細講解在這篇博文裏:ResNet——CNN經典網絡模型詳解(pytorch實現)

#train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#來自官網參數
    "val": transforms.Compose([transforms.Resize(256),#將最小邊長縮放到256
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}


data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
net = resnet34()
# net = resnet34(num_classes=5)
# load pretrain weights

model_weight_path = "./resnet34-pre.pth"
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#載入模型參數

# for param in net.parameters():
#     param.requires_grad = False
# change fc layer structure

inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)


net.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

在這裏插入圖片描述未使用遷移學習
在這裏插入圖片描述VGG16

#train.py

import torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from model import vgg
import torch
import time
import torchvision.models.vgg
from torchvision import models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

#數據預處理,從頭

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),


    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


data_root = os.path.abspath(os.path.join(os.getcwd(), "../../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set pathh

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)


# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)


batch_size = 20
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()

# model
# = models.vgg16(pretrained=True)

#
# model_name = "vgg16"
# net = vgg(model_name=model_name, init_weights=True)


# load pretrain weights
net = models.vgg16(pretrained=False)
pre = torch.load("./vgg16.pth")
net.load_state_dict(pre)

for parma in net.parameters():
    parma.requires_grad = False


net.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(4096, 4096),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(4096, 5))

# model_weight_path = "./vgg16.pth"
# missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#載入模型參數

# # for param in net.parameters():
# #     param.requires_grad = False
# # change fc layer structure
#
# inchannel = 512
# net.classifier = nn.Linear(inchannel, 5)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(net.classifier.parameters(), lr=0.001)

# loss_function = nn.CrossEntropyLoss()
# optimizer = optim.Adam(net.parameters(), lr=0.0001) #learn rate
net.to(device)

best_acc = 0.0
#save_path = './{}Net.pth'.format(model_name)
save_path = './vgg16Net.pth'
for epoch in range(15):
    # train
    net.train()
    running_loss = 0.0 #統計訓練過程中的平均損失

    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        #with torch.no_grad(): #用來消除驗證階段的loss,由於梯度在驗證階段不能傳回,造成梯度的累計
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))  #得到預測值與真實值的一個損失

        loss.backward()
        optimizer.step()#更新結點參數

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter() - t1)

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():#不去跟蹤損失梯度
        for val_data in validate_loader:
            val_images, val_labels = val_data
            #optimizer.zero_grad()
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

densenet121

#train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
from model import densenet121
import os
import torch.optim as optim
import torchvision.models.densenet
import torchvision.models as models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#來自官網參數
    "val": transforms.Compose([transforms.Resize(256),#將最小邊長縮放到256
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}


data_root = os.path.abspath(os.path.join(os.getcwd(), "../../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

#遷移學習
net = models.densenet121(pretrained=False)
model_weight_path="./densenet121-a.pth"
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict= False)

inchannel = net.classifier.in_features
net.classifier = nn.Linear(inchannel, 5)
net.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

#普通

# model_name = "densenet121"
# net = densenet121(model_name=model_name, num_classes=5)

best_acc = 0.0
save_path = './densenet121.pth'
for epoch in range(12):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

使用
在這裏插入圖片描述
參考自:莫煩

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章