pytorch特徵圖可視化

本文基於https://blog.csdn.net/GrayOnDream/article/details/99090247的博客進行了進一步的修改

因爲上述博客的網絡層順序是從network文件順序讀取class的,不適用於我的網絡(我的網絡是定義了很多基礎模塊然後拼接起來的)。因爲大多數人定義網絡的順序和真實運行的順序不太一樣,所以我在此基礎上做了修改

完整代碼如下,網絡是一個類似u-net的網絡

import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2


class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        outputs = {}
        # for name, module in self.submodule._modules.items():
        #     if "fc" in name:
        #         x = x.view(x.size(0), -1)
        #
        #     x = module(x)
        #     print(name)
        #     if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
        #         outputs[name] = x

################修改成自己的網絡,直接在network.py中return你想輸出的層

        x1,x2,x3,x4,x5,x6,up7,merge7,conv7,up8,merge8,conv8,up9,merge9,conv9,up10,merge10,conv10,up11,merge11,conv11,conv12,mask,x2_0 = self.submodule(x)
        outputs["x1"] = x1
        outputs["x2"] = x2
        outputs["x3"] = x3

        outputs["x4"] = x4
        outputs["x5"] = x5
        outputs["x6"] = x6

        outputs["up7"] = up7
        outputs["merge7"] = merge7
        outputs["conv7"] = conv7

        outputs["up8"] = up8
        outputs["merge8"] = merge8
        outputs["conv8"] = conv8

        outputs["up9"] = up9
        outputs["merge9"] = merge9
        outputs["conv9"] = conv9

        outputs["up10"] = up10
        outputs["merge10"] = merge10
        outputs["conv10"] = conv10

        outputs["up11"] = up11
        outputs["merge11"] = merge11
        outputs["conv11"] = conv11

        outputs["conv12"] = conv12
        outputs["mask"] = mask
        outputs["x2_0"] = x2_0



        # return outputs
        return outputs


def get_picture(pic_name, transform):
    img = skimage.io.imread(pic_name)
    img = skimage.transform.resize(img, (224, 224))
    img = np.asarray(img, dtype=np.float32)
    return transform(img)


def make_dirs(path):
    if os.path.exists(path) is False:
        os.makedirs(path)


def get_feature():
    pic_dir = './input_images/1.jpg' #往網絡裏輸入一張圖片
    transform = transforms.ToTensor()
    img = get_picture(pic_dir, transform)
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    # 插入維度
    img = img.unsqueeze(0)

    img = img.to(device)

    net = torch.load('./models/1_70/19.pth')
    net.to(device)
    # exact_list = None
    exact_list = ['conv1_block',""]
    dst = './features' #保存的路徑
    therd_size = 256 #有些圖太小,會放大到這個尺寸

    myexactor = FeatureExtractor(net, exact_list)
    outs = myexactor(img)
    for k, v in outs.items():
        features = v[0]
        iter_range = features.shape[0]
        for i in range(iter_range):
            # plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
            if 'fc' in k:
                continue

            feature = features.data.cpu().numpy()
            feature_img = feature[i, :, :]
            feature_img = np.asarray(feature_img * 255, dtype=np.uint8)

            dst_path = os.path.join(dst, k)

            make_dirs(dst_path)
            feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
            if feature_img.shape[0] < therd_size:
                tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
                tmp_img = feature_img.copy()
                tmp_img = cv2.resize(tmp_img, (therd_size, therd_size), interpolation=cv2.INTER_NEAREST)
                cv2.imwrite(tmp_file, tmp_img)

            dst_file = os.path.join(dst_path, str(i) + '.png')
            cv2.imwrite(dst_file, feature_img)


if __name__ == '__main__':
    get_feature()

最後的文件夾內容是這樣的:

可視化效果截圖

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章