Meta Learning技術 MAML

Learning to learn——Meta learning

Meta Learning 最常被用來解決少樣本(Few-Shot)的問題,在這邊我們介紹一篇經典的論文 Model-Agnostic Meta-Learning(MAML)。由題目可知他是一種「與模型無關的」元學習,亦即這種方法可以匹配任何使用梯度下降算法(Gradient Descent)訓練​​的模型,並能應用於各種不同的學習問題,如分類、迴歸和強化學習等。

MAML算法的目的

在 MAML 中,其目標在於一次看過多種任務(task),並希望可以學到一個可以找到所有任務「本質」的模型。舉例來說,我們小的時候學會寶特瓶可以一手握着瓶身,另一手將瓶蓋轉開;而當我們接觸到一個裝糖果的玻璃罐時,我們察覺玻璃罐與保特瓶相似的本質,因而有辦法套用既往的知識快速的移轉到新的任務上,而MAML便是在學這個過程,在遍覽多種任務後,學習一組對任務敏感的參數,當新任務進來時能快速的將先驗知識移轉到新任務中。
相對於deep learning在一個task(任務)中通過對樣本的學習以對新樣本做出判斷,元學習的目標可以看做是將task視作樣本,通過對多個task的學習,以使元模型(meta-learner)能夠對新的task做出快速而準確的學習。具體來說,就是訓練能"對特定的task產生特定的高效學習算法的算法"。至於MAML,則是嘗試訓練一個最簡單的算法——參數初始化。MAML希望訓練一組初始化參數,通過在初始參數的基礎上進行一或多步的梯度調整,來達到僅用少量數據就能快速適應新task的目的。當我們通過MAML得到了一組蠻不錯的參數,之後在類似的任務中,這組參數將會提供很好的模型初始迭代點。

算法步驟

在這裏插入圖片描述

  1. Sample batch size of tasks:首先會從meta-training裏面篩選一個batch size的training data。
  2. Evaluate gradient and compute adapted parameter:對 training data 中每一個 task 以及其對應的 label 計算屬於每個 Task 的 gradient 與更新後的 model 參數。
  3. Update the model:當有了每個task 利用training data of meta-train得到的新模型參數後,可以利用test data of meta-train驗證,並且加總所有任務的loss,對原本模型參數微分並真正的更新一次參數。
    如何評估一組初始化參數的好壞呢?最直覺的想法自然是用它和task的訓練集來訓練模型,看最後得到的正確率和所需要的迭代次數。但是一般的深度模型訓練常常要花費幾萬次的迭代來得到一個可靠的解,儘管我們可以像RNN的BPTT算法一樣把這幾萬次的過程中每一步的參數對應的梯度都考慮進去,進而更新初始化參數,但是這個過程需要的時間和空間開銷都大得驚人,我們有一種更高效的方式,就是隻進行一次參數更新,用這時的參數來計算誤差,更新初始化參數。
    解這樣的優化問題,得到的將是一個"在所有任務上,經過一次梯度下降更新後,total loss最小的初始化參數"。儘管我們不能說這是一個最好的初始化參數,但是我們可以相信這個參數將會幫助我們訓練更多的類似task。
    具體到代碼實現,考慮到task可能非常多,則是一般採取每次隨機抽取一個task,把參數代入模型,迭代更新一次;更新到第二次時,用這個Δ直接更新我們的初始化參數。
    這裏使用了一種近似,設初始參數爲Φ,則單次更新後的模型參數爲
    θ=ΦαL1Φ θ=Φ-\alpha \frac{\partial {L_1}}{\partial {Φ}}
    其中L1是第一次計算loss時的損失函數(損失函數會隨着參數變化而變化)。當計算第二次更新時,我們要計算這時的Φ關於Loss2的導數,涉及到二階導數,爲了快速計算,我們直接用
    L2Φ=L2θθΦ=L2θ(12LΦ2)L2θ \frac{\partial {L_2}}{\partial {Φ}} = \frac{\partial {L_2}}{\partial {θ}} \frac{\partial {θ}}{\partial {Φ}}= \frac{\partial {L_2}}{\partial {θ}} (1-\frac{\partial ^2{L}}{\partial {Φ}^2}) \approx \frac{\partial {L_2}}{\partial {θ}}
    的一階近似扔掉二階導數,這樣計算就變得簡單很多,每次我們只需要把初始參數(Meta)用於模型初始化,在訓練集上訓練一次更新參數,然後在計算第二次導數時,把這個導數拿出來用於更新我們的初始參數(Meta).

代碼實現

論文上給出了一組非常簡單的訓練任務集,我們要實現的就是生成 a∗sin(x+b) 的數據集,其中a,b可以調整以得到多種相似但不相同的任務。a,b 的範圍是(0,1.5) (0,2π),每個數據集有10個點。在這些數據集中,訓練一個最契合所有任務的初始化值。
(先調包,我可懶得自己寫梯度下降)

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt

然後我們設計一個函數產生我們需要的數據集,通過隨機產生a和b,就能獲得多組不同的tasks。我們這裏設置每個任務的數據size爲10

device = 'cpu'
def meta_task_data(seed = 0, a_range=[0.1, 5], b_range = [0, 2*np.pi], task_num = 100,
                   n_sample = 10, sample_range = [-5, 5], plot = False):
    np.random.seed = seed
    a_s = np.random.uniform(low = a_range[0], high = a_range[1], size = task_num)
    b_s = np.random.uniform(low = b_range[0], high = b_range[1], size = task_num)
    total_x = []
    total_y = []
    label = []
    for t in range(task_num):
        x = np.random.uniform(low = sample_range[0], high = sample_range[1], size = n_sample)
        total_x.append(x)
        total_y.append( a_s[t]*np.sin(x+b_s[t]) )
        label.append('{:.3}*sin(x+{:.3})'.format(a_s[t], b_s[t]))
    if plot:
        plot_x = [np.linspace(-5, 5, 1000)]
        plot_y = []
        for t in range(task_num):
            plot_y.append( a_s[t]*np.sin(plot_x+b_s[t]) ) 
        return total_x, total_y, plot_x, plot_y, label
    else:
        return total_x, total_y, label

bsz = 1
train_x, train_y, train_label = meta_task_data() 
train_x = torch.Tensor(train_x).unsqueeze(-1) # add one dim
train_y = torch.Tensor(train_y).unsqueeze(-1)
train_dataset = data.TensorDataset(train_x, train_y)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=bsz, shuffle=False)

test_x, test_y, plot_x, plot_y, test_label = meta_task_data(task_num=1, n_sample = 10, plot=True)  
test_x = torch.Tensor(test_x).unsqueeze(-1) # add one dim
test_y = torch.Tensor(test_y).unsqueeze(-1) # add one dim
plot_x = torch.Tensor(plot_x).unsqueeze(-1) # add one dim
test_dataset = data.TensorDataset(test_x, test_y)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=bsz, shuffle=False)  

有了上面的推導,實現起來也並不困難。我們知道了只需要用第二次梯度下降的更新直接應用到meta上,就可以另外找一個變量把meta存起來,每次要用就用它初始化模型。在第一次和第二次更新參數時,把參數記錄下來取差值,直接加到meta上,就這麼簡單。
在此之前還要先定義模型,我們定義1-50-50-1的全連接網用來處理上面的所有任務。

# 定義模型,輸入輸出爲1維,雙隱層,隱層單元50個

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc1 = nn.Linear(1, 50) 
        self.sig1 = nn.Sigmoid()
        self.fc2 = nn.Linear(50, 50)
        self.sig2 = nn.Sigmoid()
        self.fc3 = nn.Linear(50, 1)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.sig1(out)
        out = self.fc2(out)
        out = self.sig2(out)
        out = self.fc3(out)
        return out

爲了更方便地進行張量加減,我這裏把meta用一個一維的張量表示;如此,我們要額外定義從一維張量生成模型,以及從模型獲得一維張量的函數。

def model_to_array(model):
    '''
    這個函數把所有模型參數一字排開,輸出一個1維的tensor
    '''
    s_dict = model.state_dict()
    return torch.cat(
        (
        s_dict['fc1.weight'].view(-1),
        s_dict['fc1.bias'].view(-1),
        s_dict['fc2.weight'].view(-1),
        s_dict['fc2.bias'].view(-1),
        s_dict['fc3.weight'].view(-1),
        s_dict['fc3.bias'].view(-1))
    )

def array_to_model(model,arr):
    '''
    這個函數把1維的tensor的數據寫回model的dict中
    '''
    indice = 0
    s_dict = model.state_dict()
    for name,param in s_dict.items():
        length = torch.prod(torch.tensor(param.shape))
        s_dict[name] = arr[indice:indice+length].view(param.shape)
        indice += length
    model.load_state_dict(s_dict)

開始訓練,我們把所有的tasks過5000個epoch,使用隨機梯度下降,更新meta和更新模型參數的學習率都設爲0.01

model = net() # 正式的模型,用於在各個task上測試
protortype_prarms = model_to_array(model) # 我們要更新的原型參數,也就是MAML要訓練的參數

EPOCH = 5000
Prototype_LR = 0.01
Training_LR = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=Training_LR)
proto_optimizer = torch.optim.SGD(model.parameters(), lr=Prototype_LR)
loss_func = nn.MSELoss()

for epoch in range(EPOCH):
    total_loss = 0
    total_times = 0
    for step, (X,y) in enumerate(train_loader):
        X = X.view(10,1)
        y = y.view(10,1)
        # 把prototype的參數導入模型,作爲初始化
        array_to_model(model,protortype_prarms)
        # 先計算一次,更新一次參數
        yhat = model(X)
        loss = loss_func(yhat,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        params_first_step = model_to_array(model)
        
        # 然後做第二次計算,再更新一次參數,把這次更新的差值用於更新原型參數
        yhat = model(X)
        loss = loss_func(yhat,y)
        total_loss += loss.item()
        proto_optimizer.zero_grad()
        loss.backward()
        proto_optimizer.step()
        params_second_step = model_to_array(model)
        
        total_times += 1
        
        # 計算差值,更新protortype_prarms
        protortype_prarms += params_second_step
        protortype_prarms -= params_first_step
    if (epoch+1)%20==0:
        print("Epoch %d, loss:%.2f"%(epoch+1,total_loss/total_times))

我們可以畫圖來看一看meta的參數有沒有發揮作用

# 我們可以用兩張圖來看一看meta的參數有沒有發揮作用
optimizer = torch.optim.SGD(model.parameters(), lr=Training_LR)

fig = plt.figure(figsize = [9.6,7.2])
ax = plt.subplot(111)
plot_x1 = plot_x.squeeze().numpy()
ax.scatter(test_x.numpy().squeeze(), test_y.numpy().squeeze())
ax.plot(plot_x1, plot_y[0].squeeze(),label = 'origin')
# 丟入without train的model看一看
plot_y_without_train = model(plot_x.view(1000,1))
ax.plot(plot_x1, plot_y_without_train.detach().numpy().squeeze(),label = 'meta')
# train一個step,再觀察輸出
yhat = model(test_x[0])
loss = loss_func(yhat,test_y[0])
optimizer.zero_grad()
loss.backward()
optimizer.step()
plot_y_with_one_step = model(plot_x.view(1000,1))
ax.plot(plot_x1, plot_y_with_one_step.detach().numpy().squeeze(),label = '1 step')
# train 10個step,再觀察輸出
for step in range(10):
    yhat = model(test_x[0])
    loss = loss_func(yhat,test_y[0])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
plot_y_with_ten_step = model(plot_x.view(1000,1))
ax.plot(plot_x1, plot_y_with_ten_step.detach().numpy().squeeze(),label = '10 step')
ax.legend()

在這裏插入圖片描述
初始化後的函數已經有點正弦函數的樣子了,只需要再train一個step,就能讓數據點和網絡輸出相當接近,如果train10個epoch,就基本完全收斂在所有的數據點上。這是一般的參數初始化完全無法做到的,這就是meta學習到的知識。
更直觀的,我們畫出normal的初始化和meta的初始化的loss圖。
在這裏插入圖片描述
在類似的任務上有更快的收斂速度,就是meta-learning的力量。

小結

懶得寫了,有人看記得點個讚唄

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章