本文基於https://blog.csdn.net/GrayOnDream/article/details/99090247的博客進行了進一步的修改
上述博客是按照各層在網絡中定義(註冊)的順序,即遍歷 `submodule._modules`,依次調用各層來提取特徵的。這種做法不適用於我的網絡:我的網絡先定義了很多基礎模塊,再在 forward 中把它們拼接起來。由於大多數人定義網絡層的順序和實際前向傳播的順序並不一致,所以我在此基礎上做了修改,直接讓網絡的 forward 返回想要可視化的各層輸出。
完整代碼如下(所用的網絡是一個類似 U-Net 的網絡):
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2
class FeatureExtractor(nn.Module):
    """Wrap a network whose forward() returns its intermediate feature maps.

    Instead of walking ``submodule._modules`` in definition order (which does
    not match the real execution order for networks assembled from reusable
    blocks), this wrapper expects ``submodule(x)`` to return a sequence of
    tensors -- one per entry of ``layer_names``, in the same order -- and packs
    them into an ordered ``{name: tensor}`` dict.

    Args:
        submodule: the network; its forward must return the layer outputs.
        extracted_layers: stored for API compatibility with the original blog
            code. It is NOT used for filtering: every returned output is kept.
        layer_names: optional names for the returned outputs. Defaults to
            ``LAYER_NAMES``, the outputs returned by this project's network.
    """

    # Names of the tensors returned by this project's network forward(), in
    # return order. Update this (or pass ``layer_names``) if forward() changes.
    LAYER_NAMES = (
        "x1", "x2", "x3", "x4", "x5", "x6",
        "up7", "merge7", "conv7",
        "up8", "merge8", "conv8",
        "up9", "merge9", "conv9",
        "up10", "merge10", "conv10",
        "up11", "merge11", "conv11",
        "conv12", "mask", "x2_0",
    )

    def __init__(self, submodule, extracted_layers, layer_names=None):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers
        self.layer_names = tuple(self.LAYER_NAMES if layer_names is None else layer_names)

    def forward(self, x):
        """Run the network once and return ``{layer_name: output}`` in return order.

        Raises:
            ValueError: if the network returns a different number of outputs
                than there are layer names.
        """
        # tuple() mirrors the iterable-unpacking semantics of the old code.
        results = tuple(self.submodule(x))
        if len(results) != len(self.layer_names):
            raise ValueError(
                "network returned %d outputs but %d layer names are configured"
                % (len(results), len(self.layer_names))
            )
        return dict(zip(self.layer_names, results))
def get_picture(pic_name, transform, size=(224, 224)):
    """Load an image from disk, resize it, and apply ``transform``.

    Args:
        pic_name: path of the image file, read with ``skimage.io.imread``.
        transform: callable applied to the resized float32 array
            (``transforms.ToTensor()`` in this script).
        size: target (height, width) passed to ``skimage.transform.resize``;
            defaults to the previously hard-coded 224x224.

    Returns:
        Whatever ``transform`` returns for the resized float32 array.
    """
    img = skimage.io.imread(pic_name)
    img = skimage.transform.resize(img, size)
    # resize() returns a float64 array; cast to float32, presumably so the
    # resulting tensor matches the model's float32 weights.
    img = np.asarray(img, dtype=np.float32)
    return transform(img)
def make_dirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check, which avoids
    the check-then-create race. It also fails loudly if ``path`` exists but is
    not a directory, so later image writes are not silently lost.
    """
    os.makedirs(path, exist_ok=True)
def _feature_to_uint8(feature_img):
    """Scale one 2-D float feature map to a uint8 image in [0, 255].

    A plain ``feature_img * 255`` cast only works for activations already in
    [0, 1]; negative or >1 values overflow the uint8 cast and produce
    wrap-around artifacts. Min-max normalisation maps the map's actual value
    range onto [0, 255] instead.
    """
    lo = float(feature_img.min())
    hi = float(feature_img.max())
    if hi > lo:
        scaled = (feature_img - lo) / (hi - lo)
    else:
        # Constant map: nothing to stretch. Clip so a constant value in
        # [0, 1] keeps its absolute intensity, and avoid dividing by zero.
        scaled = np.clip(feature_img, 0.0, 1.0)
    return np.asarray(scaled * 255, dtype=np.uint8)


def _save_feature_maps(name, features, dst, therd_size):
    """Write each channel of one layer output, indexed as (C, H, W), as JET-coloured PNGs.

    Files go to ``<dst>/<name>/<i>.png``. Maps whose height is below
    ``therd_size`` also get a nearest-neighbour upscaled copy named
    ``<i>_<therd_size>.png``, so tiny maps are still viewable.
    """
    dst_path = os.path.join(dst, name)
    make_dirs(dst_path)
    # One device->host copy per layer instead of one per channel.
    feature = features.detach().cpu().numpy()
    for i in range(feature.shape[0]):
        feature_img = cv2.applyColorMap(_feature_to_uint8(feature[i]), cv2.COLORMAP_JET)
        if feature_img.shape[0] < therd_size:
            tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
            tmp_img = cv2.resize(feature_img, (therd_size, therd_size),
                                 interpolation=cv2.INTER_NEAREST)
            cv2.imwrite(tmp_file, tmp_img)
        dst_file = os.path.join(dst_path, str(i) + '.png')
        cv2.imwrite(dst_file, feature_img)


def get_feature(pic_dir='./input_images/1.jpg',
                model_path='./models/1_70/19.pth',
                dst='./features',
                therd_size=256,
                device=None):
    """Feed one image through a saved model and dump every feature map as PNGs.

    Args:
        pic_dir: input image fed to the network.
        model_path: checkpoint saved with ``torch.save(model)`` (whole module).
        dst: root output directory; one sub-folder per layer name.
        therd_size: maps smaller than this are additionally saved upscaled
            to ``therd_size`` x ``therd_size``.
        device: torch device or string. When None, prefer "cuda:1" as before,
            fall back to "cuda:0" on single-GPU machines, else use the CPU.
    """
    if device is None:
        if torch.cuda.device_count() > 1:
            device = "cuda:1"
        elif torch.cuda.is_available():
            device = "cuda:0"
        else:
            device = "cpu"
    device = torch.device(device)

    transform = transforms.ToTensor()
    img = get_picture(pic_dir, transform)
    img = img.unsqueeze(0).to(device)  # add the batch dimension

    # NOTE: loading a whole pickled model executes pickle -- only load trusted
    # checkpoints. map_location lets a GPU-saved model load on any device.
    try:
        # torch >= 2.6 defaults to weights_only=True, which rejects full modules.
        net = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only argument.
        net = torch.load(model_path, map_location=device)
    net.to(device)
    # Inference mode: otherwise BatchNorm/Dropout behave as in training and
    # the visualised features are not the ones the trained model produces.
    net.eval()

    exact_list = ['conv1_block', ""]
    myexactor = FeatureExtractor(net, exact_list)
    with torch.no_grad():
        outs = myexactor(img)

    for k, v in outs.items():
        if 'fc' in k:
            continue
        features = v[0]  # drop the batch dimension
        if features.dim() != 3:
            print("skip %s: expected (C, H, W) feature map, got shape %s"
                  % (k, tuple(features.shape)))
            continue
        _save_feature_maps(k, features, dst, therd_size)
def main():
    """Script entry point: dump feature maps using get_feature()'s defaults."""
    get_feature()


if __name__ == '__main__':
    main()
最終生成的文件夾內容如下:
(可視化效果截圖)