pytorch模型中間層特徵的提取

參考pytorch論壇:How to extract features of an image from a trained model

定義一個特徵提取的類:

#中間特徵提取
class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor,self).__init__()
        self.submodule = submodule
        self.extracted_layers= extracted_layers

    def forward(self, x):
        outputs = []
        for name, module in self.submodule._modules.items():
            if name is "fc": x = x.view(x.size(0), -1)
            x = module(x)
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs
#輸入數據
test_loader=DataLoader(test_dataset,batch_size=1)
img,label=iter(test_loader).next()
img, label = Variable(img, volatile=True), Variable(label, volatile=True)
#特徵輸出
myresnet=resnet18(pretrained=False)
myresnet.load_state_dict(torch.load('cafir_resnet18_1.pkl')) 
exact_list=["conv1","layer1","avgpool"]
myexactor=FeatureExtractor(myresnet,exact_list)
x=myexactor(img)
#特徵輸出可視化
import matplotlib.pyplot as plt
for i in range(64):
    ax = plt.subplot(8, 8, i + 1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
    plt.show()




發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章