PyTorch:從預訓練模型提取圖像特徵計算 feature loss

1. 背景描述

想借助一個預訓練好的網絡(非框架內置的模型)計算 feature loss。預訓練網絡地址:表情識別 net

2. 具體實操

2.1 加載模型

作者已經給出了預訓練好的模型參數和模型代碼,首先我們要把模型load進來:

    # Load the pretrained expression-recognition VGG19 network.
    # check_pth: path to PrivateTest_model.t7 downloaded from the author's site.
    from Expression.VGG import VGG
    model = VGG('VGG19')
    # Map the checkpoint to the CPU first so it loads regardless of which GPU
    # it was saved on; the model is moved to CUDA right afterwards.
    checkpoint = torch.load(check_pth, map_location='cpu')
    model.load_state_dict(checkpoint['net'])
    model.cuda()
    model.eval()  # inference mode: BatchNorm layers use their running statistics

我們可以打印一下該模型的結構:

print(torch.nn.Sequential(*model.children()))  # wrap top-level children (features + classifier) to list every layer

結果爲:

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU(inplace)
    (23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (25): ReLU(inplace)
    (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU(inplace)
    (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (35): ReLU(inplace)
    (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (38): ReLU(inplace)
    (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU(inplace)
    (43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (45): ReLU(inplace)
    (46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (48): ReLU(inplace)
    (49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (51): ReLU(inplace)
    (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (53): AvgPool2d(kernel_size=1, stride=1, padding=0)
  )
  (1): Linear(in_features=512, out_features=7, bias=True)
)

2.2 提取特徵

class FeatureExtractor(nn.Module):
    """Frozen, truncated prefix of ``model.features`` used as a feature-loss network.

    Keeps the first ``feature_layer + 1`` children of ``model.features``
    (indices ``0..feature_layer`` inclusive) and disables gradients for their
    parameters.

    Args:
        model: a network exposing a ``features`` module whose children are the
            layers to slice (e.g. the VGG printed above).
        feature_layer: index of the last ``features`` layer to keep. For the
            VGG19 printed in section 2.1, index 50 is the final ``BatchNorm2d``,
            so features are taken before the last ReLU and max-pool.
        device: not used by this class; accepted only for interface
            compatibility. Move the module yourself (e.g. ``.to(device)``).

    Note:
        The kept layers are the *same* module objects as in ``model.features``
        (``nn.Sequential`` does not copy them), so freezing them here also
        freezes the corresponding parameters of ``model``.
    """

    def __init__(self, model, feature_layer=50, device=torch.device('cpu')):
        super(FeatureExtractor, self).__init__()
        kept_layers = list(model.features.children())[:feature_layer + 1]
        self.features = nn.Sequential(*kept_layers)

        # The extractor is only a fixed loss network: never train its weights.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the output of layer ``feature_layer`` for input ``x``."""
        # NOTE(review): no input normalisation is applied here; inputs are
        # presumably expected in [0, 1] as the network was trained — confirm.
        return self.features(x)

3. 較完整代碼

def define_E(opt, check_pth='Expression/FER2013_VGG19/PrivateTest_model.t7', feature_layer=50):
    """Build a frozen feature extractor on top of the pretrained expression VGG19.

    Args:
        opt: training options; not used here (kept for interface compatibility).
        check_pth: path to the pretrained checkpoint; the state dict is read
            from its ``'net'`` key.
        feature_layer: index of the last ``VGG.features`` layer to keep,
            forwarded to ``FeatureExtractor``. The default 50 matches the
            previously hard-coded value (the final ``BatchNorm2d`` in the VGG19
            structure printed in section 2.1, assuming ``Expression.VGG``'s
            ``FeatureExtractor`` slices ``[:feature_layer + 1]`` — confirm).

    Returns:
        The ``FeatureExtractor`` built from the CUDA-resident ``VGG('VGG19')``,
        switched to eval mode.
    """
    from Expression.VGG import VGG, FeatureExtractor

    netE = VGG('VGG19')
    # Map to the CPU first so the checkpoint loads regardless of which GPU it
    # was saved on; load_state_dict copies the weights, then .cuda() moves them.
    checkpoint = torch.load(check_pth, map_location='cpu')
    netE.load_state_dict(checkpoint['net'])
    netE.cuda()
    netE.eval()
    extract_E = FeatureExtractor(model=netE, feature_layer=feature_layer)
    return extract_E.eval()
def loss(train_opt, device):
    """Compute the weighted expression (feature) loss between fake and real images.

    When ``train_opt.expression_weight > 0``, builds the expression feature
    extractor via ``define_E`` and returns
    ``expression_weight * criterion(netE(fake_img), netE(real_img))``.

    NOTE(review): ``real_img``, ``fake_img``, ``logger`` and ``DataParallel``
    are not parameters; they are resolved as module-level names and must be
    provided by the surrounding training script. ``netE`` is rebuilt (and the
    checkpoint reloaded) on every call; in a real training loop it should be
    created once and reused.

    Args:
        train_opt: options object with ``expression_weight`` (number) and
            ``expression_criterion`` (``'l1'`` or ``'l2'``).
        device: device the criterion and extractor are moved to.

    Returns:
        The weighted loss tensor, or ``0`` when the loss is disabled.

    Raises:
        NotImplementedError: if ``expression_criterion`` is neither ``'l1'``
            nor ``'l2'``.
    """
    # Written as `not (w > 0)` so a NaN weight is treated as disabled, exactly
    # like the original `if w > 0: ... else:` branch.
    if not (train_opt.expression_weight > 0):
        logger.info('Remove expression loss.')
        return 0

    l_exp_type = train_opt.expression_criterion
    if l_exp_type == 'l1':
        cri_exp = nn.L1Loss().to(device)
    elif l_exp_type == 'l2':
        cri_exp = nn.MSELoss().to(device)
    else:
        # '{}' rather than '{:s}': a non-str value (e.g. None) must still raise
        # NotImplementedError, not a TypeError from str.format.
        raise NotImplementedError('Loss type [{}] not recognized.'.format(l_exp_type))
    l_exp_w = train_opt.expression_weight

    netE = DataParallel(define_E(train_opt).to(device))

    # Real-image features are only a target: no autograd graph is needed.
    with torch.no_grad():
        real_exp = netE(real_img)
    fake_exp = netE(fake_img)
    l_g_exp = l_exp_w * cri_exp(fake_exp, real_exp)
    return l_g_exp
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章