Elastic Weight Consolidation(EWC) for Life long Learning

Life Long learning

連續學習的概念大概是在2016年以後纔開始流行的,雖然今天的工業界中幾乎都是使用一個或多個模型對應一個任務,但是爲了讓機器更像人,讓機器能同時解決多個任務,同時把過去的知識運用到新的任務上,也是值得研究的課題。

方法

  • Regularization-based methods
  • Parameter isolation methods
我們要實踐的就是這種非常基礎,但是又非常有效的方法,Elastic Weight Consolidation(EWC),基於正則化的模型長期學習方法。 這裏我們實現一個能同時實現MNIST和USPS識別的長期學習模型。

模型

使用Relu和線性層組成的全連接網絡實現多個MNIST圖像分類任務。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms

import numpy as np
import os
import random
from copy import deepcopy
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(28*28, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 128)
        self.fc6 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc5(x)
        x = self.relu(x)
        x = self.fc6(x)
        return x

EWC

EWC的基礎思想是把已經訓練好的模型中的比較重要的參數用正則化項保護起來,讓它們變得不那麼容易被更新,從而舊的知識就不會被完全洗掉。把參數的損失函數寫出則是下面的公式
LB=L(θ)+iλ2Fi(θiθA,i)2\mathcal{L}_B = \mathcal{L}(\theta) + \sum_{i} \frac{\lambda}{2} F_i (\theta_{i} - \theta_{A,i}^{*})^2
我們在基礎的損失函數上增加一個正則化項,每個參數受到自己在上一個任務訓練完畢後,最後的參數值的約束。lambda是一般的係數,其中的F評估參數重要程度。
如果一個參數很重要,它應該有很大的F。我們有很多種方法定義F,比如我們可以計算參數的對損失函數的二階偏導,也可以用更簡單的一階方法定義,比如下面的
F=[log(p(ynxn,θA)log(p(ynxn,θA)T] F = [ \nabla \log(p(y_n | x_n, \theta_{A}^{*}) \nabla \log(p(y_n | x_n, \theta_{A}^{*})^T ]
也就是我們計算給定數據集x和終末狀態參數,計算分類正確的先驗概率的梯度,再對梯度求內積。這將是一個還不錯的重要度估計量,實際上這個公式的推導還是有點東西的,是fisher信息矩陣的簡單近似,這裏不細講,只知道如何實現即可。
既然知道了怎麼計算重要度,用Pytorch實現就只需要把整個訓練集丟進模型,softmax計算概率然後取對應位置的正確概率,對所有數據點求平均,反向傳播計算梯度,對每個模型參數的梯度計算內積,就結束啦。
至於正則化就更簡單了,每次梯度下降更新參數後,再做一個L2懲罰就行了。或者直接把L2的計算和loss func加在一起再梯度下降也是可以的。

數據集使用

我們做兩個任務,一個是MNIST手寫數字識別,另一個是USPS手寫數字識別,在此之前我們需要把USPS的16x16轉爲28x28.

MNIST_transform = transforms.Compose([
    transforms.ToTensor(),
])
USPS_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])
mnist = torchvision.datasets.MNIST(
    root='C:/Users/Administrator/DL',
    train=True,                                     
    transform = MNIST_transform
)

usps = datasets.USPS(
    root = 'C:/Users/Administrator/DL',
    transform=USPS_transform,
    train = True,
    download=True
)

batch_size = 100

mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)
usps_loader = torch.utils.data.DataLoader(dataset=usps,
                                          batch_size=batch_size, 
                                          shuffle=True)

測試結果

我們需要打印出EWC的成效,在使用EWC後,即使學習了usps也不會讓mnist的正確率下降太多。首先進行普通的訓練,看一看如果不進行fine tune,在學習了一個任務後,模型還能不能在另一個任務上表現的好。

def normal_train(model, optimizer, loader, summary_epochs):
    model.train()
    model.zero_grad()
    loss_func = nn.CrossEntropyLoss()
    losses = []
    loss = 0.0
    for epoch in range(summary_epochs):
        for step, (imgs, labels) in enumerate(loader):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            ce_loss = loss_func(outputs, labels)

            optimizer.zero_grad()
            ce_loss.backward()
            optimizer.step()

            loss += ce_loss.item()
            if (step + 1) % 20 == 0:
                loss = loss / 20
                print ("\r", "Epoch {}, step {}, loss: {:.3f}      ".format(epoch + 1,step+1,loss), end=" ")
                losses.append(loss)
                loss = 0.0
                
    return losses

def verify(model, loader):
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in loader:
            images = images.reshape(-1, 28*28).to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy of the network on given dataset: {} %'.format(100 * correct / total))


model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

normal_train(model,optimizer,mnist_loader,10)
verify(model,mnist_loader)
normal_train(model,optimizer,usps_loader,10)
verify(model,usps_loader)
verify(model,mnist_loader)

 Epoch 10, step 600, loss: 0.018       
 Accuracy of the network on given dataset: 99.51333333333334 %
 Epoch 10, step 60, loss: 0.018       
 Accuracy of the network on given dataset: 99.21821423673022 %
Accuracy of the network on given dataset: 77.80166666666666 %

實現EWC的train的時候,先計算模型的參數重要度矩陣,用重要度做參數,在原loss上增加L2正則。

def ewc_train(model, optimizer, previous_loader, loader,  summary_epochs, lambda_ewc):
    # 計算重要度矩陣
    params = {n: p for n, p in model.named_parameters() if p.requires_grad}# 模型的所有參數
    
    _means = {} # 初始化要把參數限制在的參數域
    for n, p in params.items():
        _means[n] = p.clone().detach()
    
    precision_matrices = {} #重要度
    for n, p in params.items():
        precision_matrices[n] = p.clone().detach().fill_(0) #取zeros_like

    model.eval()
    for data, labels in previous_loader:
        model.zero_grad()
        data, labels = data.to(device),labels.to(device)
        output = model(data)
        ############ 核心代碼 #############
        loss = F.nll_loss(F.log_softmax(output, dim=1), labels)
        # 計算labels對應的(正確分類的)對數概率,並把它作爲loss func衡量參數重要度        
        loss.backward()  # 反向傳播計算導數
        for n, p in model.named_parameters():                         
            precision_matrices[n].data += p.grad.data ** 2 / len(previous_loader)
        ########### 計算對數概率的導數,然後反向傳播計算梯度,以梯度的平方作爲重要度 ########

    model.train()
    model.zero_grad()
    loss_func = nn.CrossEntropyLoss()
    losses = []
    loss = 0.0
    for epoch in range(summary_epochs):
        for step, (imgs, labels) in enumerate(loader):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            ce_loss = loss_func(outputs, labels)
            total_loss = ce_loss
            # 額外計算EWC的L2 loss
            ewc_loss = 0
            for n, p in model.named_parameters():
                _loss = precision_matrices[n] * (p - _means[n]) ** 2
                ewc_loss += _loss.sum()
            total_loss += lambda_ewc * ewc_loss
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            loss += total_loss.item()
            if (step + 1) % 20 == 0:
                loss = loss / 20
                print ("\r", "Epoch {}, step {}, loss: {:.3f}      ".format(epoch + 1,step+1,loss), end=" ")
                losses.append(loss)
                loss = 0.0
                
    return losses


model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

normal_train(model,optimizer,mnist_loader,10)
verify(model,mnist_loader)
ewc_train(model,optimizer,mnist_loader,usps_loader,10,350)
verify(model,usps_loader)
verify(model,mnist_loader)

Epoch 10, step 600, loss: 0.039       
Accuracy of the network on given dataset: 99.67166666666667 %
Epoch 10, step 60, loss: 0.017       
Accuracy of the network on given dataset: 99.71197366616376 %
Accuracy of the network on given dataset: 91.02666666666667 %

可以看見,使用EWC前後,USPS的訓練正確率並沒有下降,但是卻使用正則化保住了前面的MNIST任務的正確率。

小結

Life Long Learning是一種非常新奇的技術,和工業界中在使用被用來解決問題的強化學習和遷移學習技術不太一樣;雖然這種連續學習技術也是爲了把多個任務融會貫通,但是LLL更偏重於如何讓模型學會多種知識且不忘記之前的知識,用一個模型解決多種問題。
雖然這種技術現在看來是一種很沒用的技術,但是爲了讓機器更像人,實現強人工智能以及Learn to learn的夢想,這種長期學習的技術又是必不可少的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章