可视化CNN

深度学习的可解释性一直是比较差的,因为神经网络不像传统算法,可以明确的被解释机器为什么下某种判断,做某种分析。比如决策树,就可以直接告诉你因为XX特征是XX,所以我们把它归为某类。又或者SVM,因为训练集中的X1和X2是支持向量,所以在它下面的样本是正,上面的样本是负。神经网络是用大量的矩阵运算和激活函数实现一些不可思议的逻辑,我们无法看一个矩阵就猜出网络在干什么。因此我们可能需要特地设计一些,用于解释网络的算法。
以CNN为例,我们可以用可视化的方法把CNN学到了什么展示出来。

训练CNN模型

总之在可视化之前先训练一个模型。我们用一个多个卷积层,带有dropout的CNN在CIFAR-10上训练。

import os
import sys
import argparse
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset
import torch.utils.data as Data
import torchvision.transforms as transforms
import torchvision
from skimage.segmentation import slic
from pdb import set_trace

EPOCH=10
BATCH_SIZE=50
LR=0.001

# 踩了坑,这里如果要用BatchNorm2d的话一定要对数据做归一化
# 不然训练好的模型会随着batch_size大小的变动正确率下降

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_data=torchvision.datasets.CIFAR10(
    root='C:/Users/Administrator/DL/cifar10',
    train=True,
    transform=transform
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True
)
test_data=torchvision.datasets.CIFAR10(
    root='C:/Users/Administrator/DL/cifar10',
    train=False,
    transform=transform
)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE,
                                         shuffle=False)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding = 1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding = 1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding = 1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding = 1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding = 1)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.avgpool = nn.AvgPool2d(2, 2)
        self.globalavgpool = nn.AvgPool2d(8, 8)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.dropout50 = nn.Dropout(0.1)
        self.dropout10 = nn.Dropout(0.1)
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        #x = self.bn1(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        #x = self.bn1(F.relu(self.conv2(x)))
        x = self.maxpool(x)
        x = self.dropout10(x)
        x = F.relu(self.conv3(x))
        #x = self.bn2(F.relu(self.conv3(x)))
        x = F.relu(self.conv4(x))
        #x = self.bn2(F.relu(self.conv4(x)))
        x = self.avgpool(x)
        x = self.dropout10(x)
        x = F.relu(self.conv5(x))
        #x = self.bn3(F.relu(self.conv5(x)))
        x = F.relu(self.conv6(x))
        #x = self.bn3(F.relu(self.conv6(x)))
        x = self.globalavgpool(x)
        x = self.dropout50(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding = 1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding = 1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding = 1)
        self.conv5 = nn.Conv2d(128, 256, 3, padding = 1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding = 1)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.avgpool = nn.AvgPool2d(2, 2)
        self.globalavgpool = nn.AvgPool2d(8, 8)
        self.dropout50 = nn.Dropout(0.1)
        self.dropout10 = nn.Dropout(0.1)
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        x = self.dropout10(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.avgpool(x)
        x = self.dropout10(x)
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.globalavgpool(x)
        x = self.dropout50(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

import torch.optim as optim

net = Net()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

for epoch in range(10):

    running_loss = 0.
    batch_size = 100
    
    for i, data in enumerate(train_loader):
        
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        if (i+1)%100==0:
            print('[%d, %5d] loss: %.4f' %(epoch + 1, (i+1)*BATCH_SIZE, loss.item()))
        

print('Finished Training')

torch.save(net, 'cifar10.pkl')

这个模型能在测试集上拿到82的正确率,差不多可以用了。

net = torch.load('cifar10.pkl')
net.eval()

correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 82 %

Saliency map

我们把一张图片丢进 model,forward 后与 label 计算出 loss。因此与 loss 相关的有: image model parameter label 通常的情况下,我们想要改变 model parameter 来 fit image 和 label。因此 loss 在计算 backward 时我们只在乎 loss 对 model parameter 的偏微分值。但数学上 image 本身也是 continuous tensor,我们可以计算 loss 对 image 的偏微分值。这个偏微分值代表「在 model parameter 和 label 都固定下,稍微改变 image 的某个 pixel value 会对 loss 产生什么变化」。人们习惯把这个变化的剧烈程度解读成该 pixel 的重要性 (每个 pixel 都有自己的偏微分值)。因此把同一张图中,loss 对每个 pixel 的偏微分值画出来,就可以看出该图中哪些位置是 model 在判断时的重要依据。

实作上非常简单,过去我们都是 forward 后算出 loss,然后进行 backward。而这个 backward,pytorch 预设是计算 loss 对 model parameter 的偏微分值,因此我们只需要用一行 code 额外告知 pytorch,image 也是要算偏微分的对象之一。

def normalize(image):
    return (image - image.min()) / (image.max() - image.min())

def compute_saliency_maps(x, y, model):
    model.eval()
    x = x.cuda()
    # **这里需要计算图片的偏导
    x.requires_grad_()

    y_pred = model(x)
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(y_pred, y.cuda())
    loss.backward()

    saliencies = x.grad.abs().detach().cpu()
    saliencies = torch.stack([normalize(item) for item in saliencies])
    return saliencies

img_indices = [2,22,222,2222]
images = torch.tensor(train_data.train_data[img_indices]).float().permute(0, 3, 1, 2)/255.
labels = torch.tensor(train_data.train_labels)[img_indices]
saliencies = compute_saliency_maps(images, labels, net)

# 绘制saliency map
fig, axs = plt.subplots(2, len(img_indices), figsize=(15, 8))
for row, target in enumerate([images, saliencies]):
    for column, img in enumerate(target):
        axs[row][column].imshow(img.permute(1, 2, 0).numpy())

plt.show()

在这里插入图片描述
MAP会大致告诉我们,模型是因为看见了哪些东西,才产生这样的分类。

Filter explaination

可视化卷积核的过程并不复杂,对某个特定的卷积核,寻找让它的激活最大化的图片;找到的图片上就会显示出卷积核提取出的特征。
在使用Pytorch实现时,我们要把其中特定的某层某个卷积核的输出拿出,然后对它单独计算loss并用这个loss修正图片,直到迭代停止条件。

def normalize(image):
    return (image - image.min()) / (image.max() - image.min())

layer_activations = None
def filter_explaination(x, model, cnnid, filterid, iteration=100, lr=1):
    model.eval()

    def hook(model, input, output):
        global layer_activations
        layer_activations = output
        
    hook_handle = eval("net.conv"+str(cnnid)).register_forward_hook(hook)
    # 在执行到cnnid层时,先执行hook函数,再进入下一层

    # 前向传播
    model(x.cuda())
    # 获取卷积核输出
    filter_activations = layer_activations[:, filterid, :, :].detach().cpu()
    
    x = x.cuda()
    # 从原始图片开始迭代
    x.requires_grad_()
    # 计算目标函数关于x的偏导
    optimizer = Adam([x], lr=lr)
    # 用adam优化
    
    for iter in range(iteration):
        optimizer.zero_grad()
        model(x)
        objective = -layer_activations[:, filterid, :, :].sum()
        # 目标函数
        objective.backward()
        # 求导
        optimizer.step()
        # 更新
    filter_visualization = x.detach().cpu().squeeze()[0]
    hook_handle.remove()
    # 清除hook
    
    return filter_activations, filter_visualization

img_indices = [2,22,222,2222]
images = torch.tensor(train_data.train_data[img_indices]).float().permute(0, 3, 1, 2)/255.
labels = torch.tensor(train_data.train_labels)[img_indices]
filter_activations, filter_visualization = filter_explaination(
    images, net, cnnid=3 , filterid=3, iteration=100, lr=0.1)

plt.imshow(normalize(filter_visualization.permute(1, 2, 0)))
plt.show()
fig, axs = plt.subplots(2, len(img_indices), figsize=(15, 8))
for i, img in enumerate(images):
    axs[0][i].imshow(img.permute(1, 2, 0))
for i, img in enumerate(filter_activations):
    axs[1][i].imshow(normalize(img))
plt.show()

在这里插入图片描述
这里可视化了第三级的卷积层中的一个卷积核,可以看出这个卷积核会对这种交错的十字条纹感兴趣。下面的激活状态图里也可以看出,有横线和竖线的特征将会被高亮,显示出黄色。

LIME

一种基于采样和线性回归的技术…有机会再补充吧

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章