可視化CNN

深度學習的可解釋性一直是比較差的,因爲神經網絡不像傳統算法,可以明確的被解釋機器爲什麼下某種判斷,做某種分析。比如決策樹,就可以直接告訴你因爲XX特徵是XX,所以我們把它歸爲某類。又或者SVM,因爲訓練集中的X1和X2是支持向量,所以在它下面的樣本是正,上面的樣本是負。神經網絡是用大量的矩陣運算和激活函數實現一些不可思議的邏輯,我們無法看一個矩陣就猜出網絡在幹什麼。因此我們可能需要特地設計一些,用於解釋網絡的算法。
以CNN爲例,我們可以用可視化的方法把CNN學到了什麼展示出來。

訓練CNN模型

總之在可視化之前先訓練一個模型。我們用一個有多個卷積層、帶有dropout的CNN在CIFAR-10上訓練。

import os
import sys
import argparse
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset
import torch.utils.data as Data
import torchvision.transforms as transforms
import torchvision
from skimage.segmentation import slic
from pdb import set_trace

EPOCH=10        # number of full passes over the training set
BATCH_SIZE=50   # mini-batch size for both train and test loaders
LR=0.001        # learning rate for the Adam optimizer

# 踩了坑,這裏如果要用BatchNorm2d的話一定要對數據做歸一化
# 不然訓練好的模型會隨着batch_size大小的變動正確率下降

# ToTensor scales pixels to [0, 1]; Normalize then maps each channel to
# roughly [-1, 1] ((x - 0.5) / 0.5).  Normalization matters here: with
# BatchNorm layers, an un-normalized input makes accuracy depend on batch size.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_data=torchvision.datasets.CIFAR10(
    root='C:/Users/Administrator/DL/cifar10',
    train=True,
    transform=transform,
    download=True  # fetch the archive if it is not already at `root`
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True  # reshuffle every epoch for SGD
)
test_data=torchvision.datasets.CIFAR10(
    root='C:/Users/Administrator/DL/cifar10',
    train=False,
    transform=transform,
    download=True
)
# evaluation order does not matter, so no shuffling
test_loader = Data.DataLoader(
    dataset=test_data,
    batch_size=BATCH_SIZE,
    shuffle=False
)

class Net(nn.Module):
    """Six-conv-layer CNN for CIFAR-10: (B, 3, 32, 32) -> (B, 10) logits.

    NOTE(review): this definition is immediately shadowed by an identical
    second ``class Net`` below, so it is dead code in this script; it is kept
    only because the article shows it.  The bn1-bn3 layers are registered but
    never called in ``forward`` (the BatchNorm variant was commented out),
    and ``dropout50`` actually uses p=0.1 despite its name — both kept as-is
    to preserve behavior.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding = 1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding = 1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding = 1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding = 1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding = 1)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.avgpool = nn.AvgPool2d(2, 2)
        # 8x8 average pool collapses the final feature map to 1x1 per channel
        self.globalavgpool = nn.AvgPool2d(8, 8)
        self.bn1 = nn.BatchNorm2d(64)    # unused (see class docstring)
        self.bn2 = nn.BatchNorm2d(128)   # unused
        self.bn3 = nn.BatchNorm2d(256)   # unused
        self.dropout50 = nn.Dropout(0.1)  # misnamed: p is 0.1, not 0.5
        self.dropout10 = nn.Dropout(0.1)
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        # stage 1: 32x32, 64 channels
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)       # -> 16x16
        x = self.dropout10(x)
        # stage 2: 16x16, 128 channels
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.avgpool(x)       # -> 8x8
        x = self.dropout10(x)
        # stage 3: 8x8, 256 channels
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.globalavgpool(x)  # -> 1x1
        x = self.dropout50(x)
        x = x.view(x.size(0), -1)  # flatten to (B, 256)
        x = self.fc(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding = 1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding = 1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding = 1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding = 1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding = 1)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.avgpool = nn.AvgPool2d(2, 2)
        self.globalavgpool = nn.AvgPool2d(8, 8)
        self.dropout50 = nn.Dropout(0.1)
        self.dropout10 = nn.Dropout(0.1)
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        x = self.dropout10(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.avgpool(x)
        x = self.dropout10(x)
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.globalavgpool(x)
        x = self.dropout50(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

import torch.optim as optim

net = Net()

criterion = nn.CrossEntropyLoss()
# use the LR constant declared at the top of the file instead of a second
# hard-coded 0.001
optimizer = optim.Adam(net.parameters(), lr=LR)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

# EPOCH constant instead of a hard-coded 10; the unused `running_loss` and
# shadowing `batch_size = 100` locals from the original are removed.
for epoch in range(EPOCH):

    for i, data in enumerate(train_loader):

        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # report the current batch loss every 100 batches
        if (i+1)%100==0:
            print('[%d, %5d] loss: %.4f' %(epoch + 1, (i+1)*BATCH_SIZE, loss.item()))


print('Finished Training')

# save the whole module (not just state_dict) because the evaluation code
# below reloads it with torch.load(...)
torch.save(net, 'cifar10.pkl')

這個模型能在測試集上拿到82%的正確率,差不多可以用了。

# Reload the trained model; map_location lets a model saved on GPU be loaded
# on a CPU-only machine as well.
net = torch.load('cifar10.pkl', map_location=device)
net.eval()  # disable dropout so evaluation is deterministic

correct = 0
total = 0
with torch.no_grad():  # no gradients needed for accuracy counting
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # argmax over the class dimension; `.data` is unnecessary under no_grad
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 82 %

Saliency map

我們把一張圖片丟進 model,forward 後與 label 計算出 loss。因此與 loss 相關的有: image model parameter label 通常的情況下,我們想要改變 model parameter 來 fit image 和 label。因此 loss 在計算 backward 時我們只在乎 loss 對 model parameter 的偏微分值。但數學上 image 本身也是 continuous tensor,我們可以計算 loss 對 image 的偏微分值。這個偏微分值代表「在 model parameter 和 label 都固定下,稍微改變 image 的某個 pixel value 會對 loss 產生什麼變化」。人們習慣把這個變化的劇烈程度解讀成該 pixel 的重要性 (每個 pixel 都有自己的偏微分值)。因此把同一張圖中,loss 對每個 pixel 的偏微分值畫出來,就可以看出該圖中哪些位置是 model 在判斷時的重要依據。

實作上非常簡單,過去我們都是 forward 後算出 loss,然後進行 backward。而這個 backward,pytorch 預設是計算 loss 對 model parameter 的偏微分值,因此我們只需要用一行 code 額外告知 pytorch,image 也是要算偏微分的對象之一。

def normalize(image):
    """Min-max scale `image` into [0, 1].

    A constant image (max == min) maps to all zeros instead of dividing
    by zero, which previously produced NaN/inf.
    """
    span = image.max() - image.min()
    if span == 0:
        return image - image.min()  # all zeros, same dtype/shape
    return (image - image.min()) / span

def compute_saliency_maps(x, y, model):
    """Return per-pixel |d loss / d pixel| saliency maps for a batch.

    x: input images (B, C, H, W); y: integer class labels (B,);
    model: a classifier producing (B, num_classes) logits.

    The magnitude of the loss gradient w.r.t. each input pixel is used as a
    proxy for that pixel's importance; each map is min-max normalized to
    [0, 1] so they can be rendered side by side.
    """
    model.eval()
    # run on whatever device the model lives on (the original hard-coded
    # .cuda() and crashed on CPU-only machines)
    device = next(model.parameters()).device
    # clone so we never mutate the caller's tensor when flipping requires_grad
    x = x.clone().detach().to(device)
    # ask autograd for d loss / d x in addition to the parameter gradients
    x.requires_grad_()

    y_pred = model(x)
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(y_pred, y.to(device))
    loss.backward()

    saliencies = x.grad.abs().detach().cpu()
    # per-image min-max normalization to [0, 1]
    saliencies = torch.stack(
        [(s - s.min()) / (s.max() - s.min()) for s in saliencies])
    return saliencies

img_indices = [2,22,222,2222]
# `train_data`/`train_labels` attributes were deprecated and then removed
# from torchvision datasets; the supported names are `.data` (uint8 HWC
# numpy array) and `.targets` (list of ints).
images = torch.tensor(train_data.data[img_indices]).float().permute(0, 3, 1, 2)/255.
labels = torch.tensor(train_data.targets)[img_indices]
saliencies = compute_saliency_maps(images, labels, net)

# top row: original images; bottom row: their saliency maps
fig, axs = plt.subplots(2, len(img_indices), figsize=(15, 8))
for row, target in enumerate([images, saliencies]):
    for column, img in enumerate(target):
        axs[row][column].imshow(img.permute(1, 2, 0).numpy())

plt.show()

在這裏插入圖片描述
MAP會大致告訴我們,模型是因爲看見了哪些東西,才產生這樣的分類。

Filter explaination

可視化卷積核的過程並不複雜,對某個特定的卷積核,尋找讓它的激活最大化的圖片;找到的圖片上就會顯示出卷積核提取出的特徵。
在使用Pytorch實現時,我們要把其中特定的某層某個卷積核的輸出拿出,然後對它單獨計算loss並用這個loss修正圖片,直到迭代停止條件。

def normalize(image):
    """Min-max scale `image` into [0, 1], guarding against the degenerate
    constant-image case (max == min), which would otherwise divide by zero."""
    low, high = image.min(), image.max()
    if high == low:
        return image - low  # zeros with the original shape/dtype
    return (image - low) / (high - low)

layer_activations = None
def filter_explaination(x, model, cnnid, filterid, iteration=100, lr=1):
    model.eval()

    def hook(model, input, output):
        global layer_activations
        layer_activations = output
        
    hook_handle = eval("net.conv"+str(cnnid)).register_forward_hook(hook)
    # 在執行到cnnid層時,先執行hook函數,再進入下一層

    # 前向傳播
    model(x.cuda())
    # 獲取卷積核輸出
    filter_activations = layer_activations[:, filterid, :, :].detach().cpu()
    
    x = x.cuda()
    # 從原始圖片開始迭代
    x.requires_grad_()
    # 計算目標函數關於x的偏導
    optimizer = Adam([x], lr=lr)
    # 用adam優化
    
    for iter in range(iteration):
        optimizer.zero_grad()
        model(x)
        objective = -layer_activations[:, filterid, :, :].sum()
        # 目標函數
        objective.backward()
        # 求導
        optimizer.step()
        # 更新
    filter_visualization = x.detach().cpu().squeeze()[0]
    hook_handle.remove()
    # 清除hook
    
    return filter_activations, filter_visualization

img_indices = [2,22,222,2222]
# use the supported `.data`/`.targets` attributes; `train_data`/`train_labels`
# were removed from torchvision datasets
images = torch.tensor(train_data.data[img_indices]).float().permute(0, 3, 1, 2)/255.
labels = torch.tensor(train_data.targets)[img_indices]
filter_activations, filter_visualization = filter_explaination(
    images, net, cnnid=3 , filterid=3, iteration=100, lr=0.1)

# the input image evolved to maximally excite the chosen filter
plt.imshow(normalize(filter_visualization.permute(1, 2, 0)))
plt.show()
# top row: the original inputs; bottom row: the filter's activation maps
fig, axs = plt.subplots(2, len(img_indices), figsize=(15, 8))
for i, img in enumerate(images):
    axs[0][i].imshow(img.permute(1, 2, 0))
for i, img in enumerate(filter_activations):
    axs[1][i].imshow(normalize(img))
plt.show()

在這裏插入圖片描述
這裏可視化了第三級的卷積層中的一個卷積核,可以看出這個卷積核會對這種交錯的十字條紋感興趣。下面的激活狀態圖裏也可以看出,有橫線和豎線的特徵將會被高亮,顯示出黃色。

LIME

一種基於採樣和線性迴歸的技術…有機會再補充吧

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章