A Brief Summary of PyTorch and Computer Vision

For updates and the original text, see:
https://github.com/bat67/vision-with-pytorch
https://github.com/bat67/pytorch-tutorials-examples-and-books

Contents

1. Fine-tuning a pretrained ResNet
2. Saving and loading models
3. Modifying, deleting, or adding the last layer
4. Custom models: combining Sections 1-3, adding layers on top of a model
5. Custom loss functions

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms

# Build a ResNet-18; pretrained=False skips downloading the ImageNet
# weights (for actual fine-tuning you would usually pass pretrained=True).
model = models.resnet18(pretrained=False)

1. Fine-tuning a pretrained ResNet

Let's first study the layers of this model, and then decide which layers to freeze. Freezing means we want those layers' parameters to stay fixed. Fine-tuning simply means taking a model pretrained on a large dataset and continuing to train it on our target dataset. Of course, we could also skip fine-tuning and train from scratch, but that would mean reinventing the wheel; the reasons are explained below.

Suppose I want to train a model to distinguish cars from bicycles. Now, I could collect images of these two classes and train a network from scratch. But, given all the work that already exists, it is easy to find a well-trained model that recognizes dogs, cats, and people. Admittedly, none of these three classes looks anything like a car or a bicycle. However, it is still better than nothing. We can start from this model and train it to tell cars apart from bicycles.

The benefits:

  • it will be faster,
  • we will need far fewer images of cars and bicycles.

If you are interested in transfer learning, see: http://cs231n.github.io/transfer-learning
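
As a quick end-to-end illustration, here is a minimal sketch of the usual fine-tuning recipe. It assumes the torchvision ImageNet weights are available for download; finetune_model and NUM_CLASSES are our own illustrative names:

# Minimal fine-tuning recipe (sketch): start from ImageNet weights and
# swap the classifier head for our own two classes (cars vs. bicycles).
from torchvision import models
import torch.nn as nn

finetune_model = models.resnet18(pretrained=True)
NUM_CLASSES = 2
finetune_model.fc = nn.Linear(finetune_model.fc.in_features, NUM_CLASSES)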

Now, let's look inside resnet18. For this we use the .children() function, which lets us see the contents of the model's different layers. Then, we use the .parameters() function to access the parameters/weights of any layer. Finally, every parameter has an attribute .requires_grad, which defines whether a parameter is trained or frozen. By default it is True, and the network updates the parameter on every iteration. If it is set to False, the parameter is not updated and is said to be "frozen".

child_counter = 0
for child in model.children():
    print(" child", child_counter, "is -")
    print(child)
    child_counter += 1

Output:

child 0 is -
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
 child 1 is -
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 child 2 is -
ReLU(inplace)
 child 3 is -
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
 child 4 is -
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
 child 5 is -
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
 child 6 is -
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
 child 7 is -
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
 child 8 is -
AvgPool2d(kernel_size=7, stride=1, padding=0)
 child 9 is -
Linear(in_features=512, out_features=1000, bias=True)

Now you can see that some of the children are actually big chunks with layers inside them. To access a layer one level deeper, we can also call .children() on a child.
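
For example, here is a minimal sketch of going one level deeper (child 4 is the first Sequential of BasicBlocks in the output above):

# Peek inside child 4 by calling .children() on it as well.
fourth_child = list(model.children())[4]
for grandchild in fourth_child.children():
    print(grandchild)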

Suppose we want to freeze all parameters up to and including the first BasicBlock of child 6. First, let's look at what a single parameter looks like before we freeze anything:

# Look at just the first parameter of the first child, then stop.
for child in model.children():
    for param in child.parameters():
        print("This is what a parameter looks like - \n", param)
        break
    break

Output:

This is what a parameter looks like - 
 Parameter containing:
tensor([[[[ 0.0215,  0.0528,  0.0342,  ...,  0.0086, -0.0114,  0.0041],
          [ 0.0004,  0.0073,  0.0066,  ..., -0.0115, -0.0238, -0.0034],
          [-0.0017, -0.0078, -0.0114,  ...,  0.0236,  0.0338,  0.0080],
          ...,
          [ 0.0261,  0.0273, -0.0231,  ...,  0.0469, -0.0048,  0.0094],
          [-0.0027,  0.0285,  0.0030,  ..., -0.0260,  0.0206,  0.0365],
          [-0.0185,  0.0175, -0.0042,  ..., -0.0078, -0.0132, -0.0199]],

         [[ 0.0152, -0.0347,  0.0170,  ...,  0.0012,  0.0502, -0.0005],
          [-0.0185, -0.0213, -0.0167,  ...,  0.0144, -0.0169,  0.0038],
          [-0.0605, -0.0036, -0.0019,  ..., -0.0112,  0.0087,  0.0223],
          ...,
          [-0.0331, -0.0077,  0.0107,  ..., -0.0040, -0.0057,  0.0158],
          [-0.0075,  0.0082, -0.0406,  ...,  0.0102, -0.0123, -0.0018],
          [-0.0211,  0.0258, -0.0119,  ..., -0.0261,  0.0303,  0.0390]],

         [[-0.0280,  0.0115, -0.0209,  ...,  0.0384, -0.0050, -0.0155],
          [-0.0511,  0.0144, -0.0240,  ..., -0.0469, -0.0738, -0.0017],
          [-0.0222, -0.0044, -0.0265,  ..., -0.0671, -0.0002,  0.0076],
          ...,
          [-0.0162,  0.0129, -0.0302,  ..., -0.0197, -0.0143,  0.0035],
          [-0.0051, -0.0330, -0.0027,  ..., -0.0348,  0.0190,  0.0001],
          [-0.0299,  0.0401,  0.0162,  ..., -0.0360, -0.0065,  0.0054]]],


        [[[ 0.0118,  0.0433,  0.0365,  ...,  0.0402,  0.0078,  0.0029],
          [ 0.0050, -0.0310,  0.0113,  ...,  0.0267, -0.0045,  0.0193],
          [-0.0196, -0.0112,  0.0022,  ..., -0.0137, -0.0400,  0.0375],
          ...,
          [ 0.0151, -0.0172,  0.0516,  ...,  0.0283, -0.0392,  0.0039],
          [-0.0136, -0.0004,  0.0151,  ..., -0.0525, -0.0084, -0.0140],
          [-0.0740,  0.0254,  0.0247,  ..., -0.0185, -0.0250,  0.0213]],

         [[ 0.0161, -0.0205, -0.0411,  ...,  0.0010,  0.0293, -0.0255],
          [-0.0252, -0.0155,  0.0183,  ...,  0.0329,  0.0010,  0.0369],
          [-0.0136, -0.0192,  0.0033,  ...,  0.0413, -0.0121,  0.0059],
          ...,
          [-0.0174,  0.0145,  0.0107,  ...,  0.0018,  0.0450, -0.0040],
          [ 0.0829,  0.0208, -0.0108,  ..., -0.0291, -0.0362, -0.0288],
          [ 0.0026, -0.0257,  0.0152,  ...,  0.0356,  0.0076, -0.0048]],

         [[-0.0098, -0.0167,  0.0031,  ..., -0.0252,  0.0185,  0.0408],
          [ 0.0396,  0.0190,  0.0127,  ...,  0.0015,  0.0048, -0.0003],
          [ 0.0453, -0.0009, -0.0130,  ...,  0.0021, -0.0186,  0.0081],
          ...,
          [ 0.0187,  0.0306,  0.0213,  ...,  0.0210,  0.0307, -0.0107],
          [-0.0172, -0.0505, -0.0307,  ...,  0.0165,  0.0270, -0.0207],
          [ 0.0369, -0.0365, -0.0053,  ...,  0.0260,  0.0132, -0.0038]]],


        [[[-0.0273,  0.0146, -0.0249,  ...,  0.0076, -0.0065,  0.0224],
          [ 0.0089, -0.0119, -0.0368,  ...,  0.0084,  0.0072,  0.0114],
          [-0.0079, -0.0172, -0.0194,  ...,  0.0057, -0.0238,  0.0199],
          ...,
          [ 0.0344,  0.0110,  0.0175,  ...,  0.0069, -0.0197,  0.0363],
          [ 0.0353, -0.0064,  0.0152,  ...,  0.0181,  0.0489, -0.0113],
          [-0.0496, -0.0119,  0.0305,  ...,  0.0204,  0.0095, -0.0072]],

         [[ 0.0194, -0.0249,  0.0267,  ..., -0.0324,  0.0048, -0.0409],
          [-0.0170,  0.0314,  0.0284,  ..., -0.0135, -0.0624,  0.0074],
          [-0.0033, -0.0315, -0.0196,  ...,  0.0169, -0.0269, -0.0469],
          ...,
          [ 0.0071,  0.0211, -0.0085,  ...,  0.0080, -0.0157, -0.0132],
          [-0.0176,  0.0330,  0.0847,  ...,  0.0636, -0.0130, -0.0436],
          [ 0.0009, -0.0204, -0.0455,  ..., -0.0104,  0.0138, -0.0272]],

         [[ 0.0014,  0.0131,  0.0109,  ..., -0.0306, -0.0080, -0.0100],
          [-0.0447,  0.0356, -0.0170,  ..., -0.0196, -0.0394, -0.0030],
          [ 0.0133,  0.0217,  0.0251,  ...,  0.0014,  0.0168, -0.0128],
          ...,
          [ 0.0195, -0.0328,  0.0026,  ...,  0.0181, -0.0079,  0.0053],
          [ 0.0066, -0.0115, -0.0058,  ..., -0.0279,  0.0086, -0.0293],
          [ 0.0155, -0.0275,  0.0301,  ..., -0.0175,  0.0125,  0.0422]]],


        ...,


        [[[ 0.0085,  0.0286, -0.0193,  ...,  0.0575,  0.0624, -0.0180],
          [-0.0121,  0.0097, -0.0041,  ..., -0.0046, -0.0297, -0.0132],
          [ 0.0193,  0.0319, -0.0011,  ...,  0.0842,  0.0024,  0.0216],
          ...,
          [ 0.0030, -0.0581,  0.0155,  ...,  0.0116, -0.0309, -0.0021],
          [ 0.0102,  0.0271, -0.0054,  ..., -0.0410,  0.0046,  0.0071],
          [-0.0105, -0.0302, -0.0786,  ..., -0.0722, -0.0223,  0.0205]],

         [[ 0.0102, -0.0060,  0.0071,  ...,  0.0557, -0.0179, -0.0071],
          [ 0.0133,  0.0215, -0.0192,  ...,  0.0336,  0.0003, -0.0370],
          [ 0.0736,  0.0302,  0.0182,  ..., -0.0492, -0.0304, -0.0157],
          ...,
          [ 0.0178,  0.0133, -0.0096,  ..., -0.0147, -0.0118,  0.0109],
          [-0.0349,  0.0127, -0.0229,  ..., -0.0048, -0.0411,  0.0271],
          [-0.0292,  0.0062,  0.0508,  ..., -0.0150, -0.0263, -0.0092]],

         [[-0.0117,  0.0185, -0.0792,  ..., -0.0329, -0.0041, -0.0155],
          [ 0.0106,  0.0026,  0.0278,  ..., -0.0145, -0.0515, -0.0069],
          [ 0.0057,  0.0027,  0.0075,  ...,  0.0102, -0.0022, -0.0311],
          ...,
          [ 0.0505,  0.0170,  0.0238,  ..., -0.0346,  0.0114, -0.0517],
          [ 0.0036,  0.0283,  0.0039,  ..., -0.0094, -0.0055,  0.0056],
          [-0.0140, -0.0227, -0.0469,  ..., -0.0151, -0.0306, -0.0058]]],


        [[[-0.0027,  0.0185, -0.0161,  ...,  0.0008,  0.0461,  0.0178],
          [ 0.0129,  0.0056, -0.0212,  ...,  0.0129,  0.0171, -0.0352],
          [-0.0362, -0.0306,  0.0026,  ...,  0.0452,  0.0111, -0.0059],
          ...,
          [-0.0104, -0.0059, -0.0304,  ..., -0.0013,  0.0077,  0.0017],
          [-0.0069, -0.0100, -0.0133,  ..., -0.0098,  0.0053,  0.0027],
          [ 0.0108,  0.0390, -0.0113,  ..., -0.0073, -0.0112,  0.0154]],

         [[-0.0212,  0.0204,  0.0258,  ..., -0.0053,  0.0176, -0.0476],
          [ 0.0125,  0.0392,  0.0196,  ..., -0.0427, -0.0190, -0.0223],
          [-0.0120,  0.0075, -0.0316,  ...,  0.0149,  0.0112,  0.0178],
          ...,
          [-0.0185, -0.0145, -0.0032,  ..., -0.0054,  0.0354,  0.0254],
          [ 0.0082, -0.0540,  0.0323,  ..., -0.0058, -0.0363,  0.0051],
          [ 0.0162,  0.0447, -0.0112,  ..., -0.0144,  0.0131, -0.0733]],

         [[-0.0316,  0.0220, -0.0081,  ..., -0.0192, -0.0206, -0.0620],
          [-0.0737, -0.0132, -0.0192,  ..., -0.0172, -0.0007, -0.0463],
          [ 0.0073, -0.0223,  0.0055,  ...,  0.0385,  0.0187,  0.0185],
          ...,
          [-0.0256, -0.0416,  0.0338,  ..., -0.0006,  0.0518, -0.0334],
          [-0.0211, -0.0207,  0.0294,  ...,  0.0350,  0.0289,  0.0049],
          [-0.0070, -0.0264, -0.0694,  ...,  0.0305, -0.0267, -0.0405]]],


        [[[ 0.0054, -0.0061, -0.0005,  ..., -0.0087,  0.0184, -0.0149],
          [-0.0015,  0.0065,  0.0272,  ...,  0.0328, -0.0278, -0.0014],
          [ 0.0153,  0.0215, -0.0143,  ...,  0.0430, -0.0258,  0.0106],
          ...,
          [ 0.0265, -0.0427, -0.0455,  ..., -0.0402, -0.0110, -0.0325],
          [-0.0201,  0.0074,  0.0058,  ..., -0.0274,  0.0552,  0.0669],
          [-0.0108,  0.0201,  0.0139,  ..., -0.0397, -0.0451, -0.0204]],

         [[-0.0199,  0.0304,  0.0293,  ..., -0.0404,  0.0075,  0.0160],
          [-0.0208,  0.0019, -0.0206,  ...,  0.0091,  0.0293, -0.0191],
          [ 0.0391, -0.0032, -0.0169,  ..., -0.0520,  0.0202, -0.0186],
          ...,
          [ 0.0179, -0.0185,  0.0187,  ...,  0.0124, -0.0003, -0.0040],
          [-0.0037, -0.0505, -0.0034,  ...,  0.0233, -0.0063,  0.0286],
          [ 0.0469,  0.0474,  0.0398,  ...,  0.0244, -0.0162, -0.0182]],

         [[ 0.0022, -0.0071,  0.0236,  ...,  0.0192, -0.0027, -0.0632],
          [ 0.0016, -0.0199,  0.0360,  ...,  0.0125, -0.0014, -0.0034],
          [-0.0044,  0.0121,  0.0031,  ...,  0.0247,  0.0131,  0.0165],
          ...,
          [-0.0087, -0.0038, -0.0167,  ..., -0.0427, -0.0289, -0.0225],
          [-0.0392,  0.0018, -0.0007,  ...,  0.0366, -0.0056,  0.0160],
          [-0.0223,  0.0245,  0.0063,  ..., -0.0044, -0.0111, -0.0155]]]],
       requires_grad=True)

Clearly, training would involve a huge amount of computation on tensors like this. So if we freeze the parameters of the first 6 children, training gets a noticeable speedup. Now, let's freeze everything up to the first BasicBlock of child 6:

child_counter = 0
for child in model.children():
    if child_counter < 6:
        print("child ",child_counter," was frozen")
        for param in child.parameters():
            param.requires_grad = False
    elif child_counter == 6:
        children_of_child_counter = 0
        for children_of_child in child.children():
            if children_of_child_counter < 1:
                for param in children_of_child.parameters():
                    param.requires_grad = False
                print('child ', children_of_child_counter, 'of child',child_counter,' was frozen')
            else:
                print('child ', children_of_child_counter, 'of child',child_counter,' was not frozen')
            children_of_child_counter += 1

    else:
        print("child ",child_counter," was not frozen")
    child_counter += 1

Output:

child  0  was frozen
child  1  was frozen
child  2  was frozen
child  3  was frozen
child  4  was frozen
child  5  was frozen
child  0 of child 6  was frozen
child  1 of child 6  was not frozen
child  7  was not frozen
child  8  was not frozen
child  9  was not frozen

Important note

Now that this part of the network is frozen, the next thing to do is to get training to actually run. That comes down to your optimizer. The optimizer is what updates the model's parameters; typically, we write:

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.1)

However, this will give you an error (at least in older versions of PyTorch), because it tries to update all of the model's parameters, but you have set some of them to frozen! The way to pass only the parameters that are still being updated is:

optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
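
As a quick sanity check (our own addition, not from the original text), you can count how many weights are still trainable after freezing:

# Count trainable vs. frozen weights to verify the freezing worked.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print("trainable:", trainable, " frozen:", frozen)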

2. Saving and loading models

There are two main ways to save a model in PyTorch. The recommended one is using "state dictionaries" (state_dict). They are faster and need less space. A state dict knows nothing about the model's structure; it is just the values of the parameters/weights. So you must recreate the model's architecture and then load these parameters into it. The architecture is declared just as we did above.

# Let's assume we will save/load from a path MODEL_PATH

# Saving a model
torch.save(model.state_dict(), MODEL_PATH)

# Loading the model.

# First create a model and define its architecture as done above in this
# notebook. (If you want a custom architecture, that is covered below.)
checkpoint = torch.load(MODEL_PATH)
model.load_state_dict(checkpoint)
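
For completeness, the other (not recommended) way is to serialize the entire model object with pickle. A short sketch:

# Saving the whole model object (architecture + weights) via pickle.
# Less portable: loading needs the original class definitions importable.
torch.save(model, MODEL_PATH)

# Loading it back; no need to re-declare the architecture first.
model = torch.load(MODEL_PATH)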

3. Modifying, deleting, or adding the last layer

Most people using PyTorch (especially if they have used Keras) dislike the fact that they can't just remove the last layer with .pop(). So, let's see how these things are done.

Modifying the last layer

# Load the model
model = models.resnet18(pretrained=False)

# Get the number of features going into the last layer;
# we need this to change the final layer.
num_final_in = model.fc.in_features

# The final layer of the model is model.fc, so we can simply overwrite it
# so that output size = the number of classes we need. Say, 300 classes.
NUM_CLASSES = 300
model.fc = nn.Linear(num_final_in, NUM_CLASSES)

Deleting the last layer (often, when you need a layer's features)

# Load the model
model = models.resnet18(pretrained=False)

We can get the layers with model.children() as before. Then, we can turn that into a list with the list() command, and we can remove the last layer by indexing the list. Finally, we can use the PyTorch class nn.Sequential() to stack the modified list into a new model. You can edit the list in any way you want; that is, if you want the image features from the third layer from the end, you can delete the last two layers. You can even remove layers from the middle of the model. But obviously, that would lead to an incorrect number of features going into the layers after it, since most layers change the size of the image. In that case, you can index the specific layer of the model and overwrite it, just as shown above.

# Remove the last layer (the fc classifier)
new_model = nn.Sequential(*list(model.children())[:-1])
# Remove the last two layers (average pool and fc)
new_model_2_removed = nn.Sequential(*list(model.children())[:-2])
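
As a quick illustration (our own example), running a dummy batch through the truncated model shows the feature shape it produces:

# Everything up to (and including) the average pool survives in new_model.
dummy = torch.randn(1, 3, 224, 224)
features = new_model(dummy)
print(features.shape)   # torch.Size([1, 512, 1, 1])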

Adding layers

Say you want to add a fully connected layer to the model we have now. One obvious way would be to edit the list discussed above and append another layer to it. However, often we have trained such a model and want to see if we can load it and add a new layer on top of it. As mentioned above, a loaded model must have the same architecture as the saved one, so we can't use the list method; we need to add layers on top. The way to do this in PyTorch is simple: we just create a custom model! That brings us to the next section: creating custom models.

4. Custom models: combining Sections 1-3, adding layers on top of a model

Let's make a custom model. As mentioned above, we will load half of the model from a pretrained network. This seems complicated, right? Half the model is trained, half is new. Furthermore, we want some of its layers frozen and some updatable. Really, once you have done this, you can do anything with model architectures in PyTorch.

# Some imports first
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
from torchvision import datasets, models, transforms

# New models are defined as classes. Then, when we want a model, we create an object by instantiating that class.
class Resnet_Added_Layers_Half_Frozen(nn.Module):
    def __init__(self,LOAD_VIS_URL=None):
        super(Resnet_Added_Layers_Half_Frozen, self).__init__()
    
        # Start with the resnet model and swap out the final layer,
        # because that's the model we defined (and saved) above.
        model = models.resnet18(pretrained=False)
        num_final_in = model.fc.in_features
        model.fc = nn.Linear(num_final_in, 300)
        
        # Now that the architecture is defined same as above, let's load the model we would have trained above. 
        checkpoint = torch.load(MODEL_PATH)
        model.load_state_dict(checkpoint)
        
        
        # Let's freeze the same as above. Same code as above without the print statements
        child_counter = 0
        for child in model.children():
            if child_counter < 6:
                for param in child.parameters():
                    param.requires_grad = False
            elif child_counter == 6:
                children_of_child_counter = 0
                for children_of_child in child.children():
                    if children_of_child_counter < 1:
                        for param in children_of_child.parameters():
                            param.requires_grad = False
                    # Increment for every block (the original only did this in
                    # an else branch, which would have frozen every block
                    # rather than just the first).
                    children_of_child_counter += 1
            child_counter += 1
        
        # Now, let's define new layers that we want to add on top. 
        # Basically, these are just objects we define here. The "adding on top" is defined by the forward()
        # function which decides the flow of the input data into the model.
        
        # NOTE - Even the above model needs to be passed to self.
        # We drop the final fc layer ([:-1]) so vismodel ends at the average
        # pool; otherwise the fc layer would break the Sequential (there is
        # no flatten), and the new layers below replace the classifier anyway.
        self.vismodel = nn.Sequential(*list(model.children())[:-1])
        self.projective = nn.Linear(512, 400)
        self.nonlinearity = nn.ReLU(inplace=True)
        self.projective2 = nn.Linear(400, 300)
        
    
    # The forward function defines the flow of the input data and thus decides which layer/chunk goes on top of what.
    def forward(self,x):
        x = self.vismodel(x)
        x = torch.squeeze(x)
        x = self.projective(x)
        x = self.nonlinearity(x)
        x = self.projective2(x)
        return x
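
A hedged usage sketch (it assumes MODEL_PATH points at the checkpoint saved in Section 2, since the class loads it inside __init__):

# Instantiate the custom model and push a dummy batch through it.
model = Resnet_Added_Layers_Half_Frozen()
dummy = torch.randn(4, 3, 224, 224)
out = model(dummy)
print(out.shape)   # torch.Size([4, 300])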

5. Custom loss functions

Now that our model is ready, we can load anything and create any architecture we want. That leaves two important components of the pipeline: loading the data, and the training part. Let's look at the training part. The two most important pieces of this step are the optimizer and the loss function. The loss function quantifies how far our existing model is from where we want it to be, and the optimizer decides how to update the parameters so that we can minimize the loss.

Sometimes, we need to define our own loss functions. Here are a few things to know (see the short demo after this list):

  • Custom loss functions are also defined using a custom class. Like custom models, they inherit from torch.nn.Module.
  • Often, we need to change the dimensions of one of the inputs. This can be done with the view() function.
  • If we want to add a dimension to a tensor, use the unsqueeze() function.
  • The value finally returned by a loss function must be a scalar value, not a vector/tensor.
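
A tiny demo of those reshaping tools (our own example):

t = torch.randn(5, 10)
flat = t.view(50)          # reshape into a flat vector of 50 elements
expanded = t.unsqueeze(1)  # add a dimension: shape becomes (5, 1, 10)
print(flat.shape, expanded.shape)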

Below I show a custom loss called Regress_Loss, which takes two inputs x and y. It reshapes x and y, then computes an L2 loss between them and returns it. This is something you will run into frequently when training networks.

Say x has shape (5, 10) and y has shape (5, 5, 10). We need to add a dimension to x, then repeat it along the added dimension to match y's shape. Then (x - y) has shape (5, 5, 10), and we need to sum over all three dimensions, i.e. three torch.sum() calls, to get a scalar.

class Regress_Loss(torch.nn.Module):

    def __init__(self):
        super(Regress_Loss, self).__init__()

    def forward(self, x, y):
        y_shape = y.size()[1]
        x_added_dim = x.unsqueeze(1)
        # Repeat x along the added dimension to match y's shape.
        # (The original used an undefined NUM_WORDS here; y_shape is meant.)
        x_stacked_along_dimension1 = x_added_dim.repeat(1, y_shape, 1)
        diff = torch.sum((y - x_stacked_along_dimension1) ** 2, 2)
        totloss = torch.sum(torch.sum(torch.sum(diff)))
        return totloss
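
A usage sketch matching the shapes discussed above:

loss_fn = Regress_Loss()
x = torch.randn(5, 10)
y = torch.randn(5, 5, 10)
print(loss_fn(x, y))   # a scalar tensor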