一分鐘教你在PyTorch跑模型的時候提取中間層查看圖片

首先導入需要的庫

import torch
import torch.nn as nn  # 網絡庫
import torch.nn.functional as F
import torch.optim as optim  # 優化器

import pandas as pd  # 數據處理

import torchvision
import torchvision.transforms as transforms   

torch.set_grad_enabled(True)
torch.set_printoptions(linewidth=120)

from torch.utils.data import DataLoader  # 數據庫
from IPython.display import display, clear_output  # 可以在notebook版本中顯示pandas

from torchvision import transforms
from PIL import Image
import cv2

然後搭建模型,我們這裏將cnn從第一層分開了,只留下了第一層的卷積層,也就是我們要查看第一層卷積後的樣子

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)  # 定義卷積層
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)  # 定義線性層
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)  # 定義輸出層
    
    # 開始搭建
    def forward(self, t):
        # (1) 輸入層,其實可以不寫,但寫成這樣更清楚
        t = t

        # (2) 第一個卷積層,卷積層後接激勵函數再接最大池化層
        t = self.conv1(t)
        
        # t = F.relu(t)
        # t = F.max_pool2d(t, kernel_size=2, stride=2)

        # # (3) 第二個卷積層,卷積層後接激勵函數再接最大池化層
        # t = self.conv2(t)
        # t = F.relu(t)
        # t = F.max_pool2d(t, kernel_size=2, stride=2)

        # # (4) 第一個線性層,進線性層之前先reshape成1*192(12*4*4)的維度,reshape函數中的-1代表n,根據
        # t = t.reshape(-1, 12 * 4 *4)
        # print('after reshaped:')
        # print(t.shape)
        # t = self.fc1(t)
        # t = F.relu(t)

        # # (5) 第二個線性層
        # t = self.fc2(t)
        # t = F.relu(t)

        # # (6) 輸出層
        # t = self.out(t)

        return t

# 獲取預測正確的個數
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

接下來加載數據

train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1)
batch = next(iter(train_loader))
images, labels = batch

創建模型,並將圖片輸入進去返回一個tensor,這個tensor就是我們得到的卷積後的六個圖片(因爲我們輸出通道爲6)

net = Network()
t = net(images)

把第一維度的bach_size給去掉

print(t.shape)
t = t.squeeze(0)
print(t.shape)  # 此處應該是[6, 24, 24]

這一步就是顯示出所有圖片,自己根據你的通道數改plt.subplot函數的前兩個參數,我這邊是6張圖片,所有我就繪製了2*3的圖片矩陣,效果如下圖:

from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt 
import numpy as np
from PIL import Image
%matplotlib inline
#Tensor轉成PIL.Image重新顯示
for i in range(1, len(t) + 1):
    plt.subplot(2,3,i), plt.title('imge:oringe')
    new_img_PIL = transforms.ToPILImage()(t[i-1]).convert()
    imshow(np.asarray(new_img_PIL))

我是在notebook中執行的,如果在pycharm中可能需要刪掉inline這行
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章