加入triplet loss的ReID PyTorch實現

把triplet loss加到reid的實現中了,工程目錄結構如下圖所示:

loss.py代碼如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

def euclidean_dist(x,y):
    """Return the pairwise Euclidean distance matrix between rows of x and y.

    Args:
        x: 2-D tensor of shape (m, d).
        y: 2-D tensor of shape (n, d).

    Returns:
        Tensor of shape (m, n); entry (i, j) is the Euclidean distance
        between x[i] and y[j]. Squared distances are clamped to at least
        1e-12 before the square root, so the smallest returned value is 1e-6.
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # ||x||^2 + ||y||^2 - 2 * x @ y^T, using the keyword form of addmm
    # (the positional (beta, alpha, mat1, mat2) overload is deprecated).
    dist = torch.addmm(xx + yy, x, y.t(), beta=1, alpha=-2)
    # Clamp before sqrt: rounding can push tiny squared distances below
    # zero, and sqrt has an infinite gradient at exactly zero.
    dist = dist.clamp(min=1e-12).sqrt()
    return dist

def cosine_dist(x,y):
    """Return the pairwise cosine distance (1 - cosine similarity).

    Args:
        x: 2-D tensor of shape (bs1, d).
        y: 2-D tensor of shape (bs2, d).

    Returns:
        Tensor of shape (bs1, bs2); entry (i, j) is
        1 - <x[i], y[j]> / (||x[i]|| * ||y[j]||).
    """
    rows, cols = x.size(0), y.size(0)
    dot = torch.matmul(x, y.transpose(0, 1))
    # Column vector of x norms times row vector of y norms broadcasts to
    # the full (rows, cols) grid of norm products.
    x_norm = torch.sqrt(torch.pow(x, 2).sum(dim=1)).view(rows, 1)
    y_norm = torch.sqrt(torch.pow(y, 2).sum(dim=1)).view(1, cols)
    norm_prod = x_norm * y_norm
    return 1 - dot / norm_prod

def _batch_hard(mat_distance,mat_similarity,indice=False):
    sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-100000.0)*(1 - mat_similarity),dim=1, descending=True)
    hard_p = sorted_mat_distance[:,0]
    hard_p_indice = positive_indices[:,0]
    sorted_mat_distance, negative_indices = torch.sort( mat_distance + 100000.0 * mat_similarity,dim = 1,descending=False )
    hard_n = sorted_mat_distance[:,0]
    hard_n_indice = negative_indices[:,0]
    if(indice):
        return hard_p, hard_n, hard_p_indice, hard_n_indice
    return hard_p, hard_n

class TripletLoss(nn.Module):
    """Triplet loss over the hardest positive/negative of each sample.

    Every sample in the batch is used as an anchor. Its hardest positive
    and hardest negative are chosen by ``_batch_hard`` from the pairwise
    distances computed by ``euclidean_dist``, and the pair is scored with
    ``nn.MarginRankingLoss``.
    """

    def __init__(self, margin=0.5, normalize_feature = True):
        """Store the margin and whether embeddings are L2-normalized first."""
        super(TripletLoss, self).__init__()
        self.normalize_feature = normalize_feature
        self.margin = margin
        self.margin_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, emb, label):
        """Compute the loss for a batch of embeddings.

        Args:
            emb: 2-D tensor of embeddings, one row per sample.
            label: labels of the samples; expanded to an N x N grid, so it
                must be broadcastable to that shape (N = number of rows).

        Returns:
            (loss, prec): the margin ranking loss, and the fraction of
            anchors whose hardest negative is farther than the hardest
            positive.
        """
        feats = F.normalize(emb) if self.normalize_feature else emb

        pairwise = euclidean_dist(feats, feats)
        assert pairwise.size(0) == pairwise.size(1)
        batch = pairwise.size(0)

        # same_id[i, j] is 1.0 when samples i and j carry the same label.
        label_grid = label.expand(batch, batch)
        same_id = label_grid.eq(label_grid.t()).float()

        dist_ap, dist_an = _batch_hard(pairwise, same_id)
        assert dist_an.size(0) == dist_ap.size(0)

        # Target of +1 asks for dist_an to exceed dist_ap by the margin.
        target = torch.ones_like(dist_ap)
        loss = self.margin_loss(dist_an, dist_ap, target)

        prec = (dist_an.data > dist_ap.data).sum() * 1.0 / target.size(0)
        return loss, prec


# loss = nn.CrossEntropyLoss()
# an = torch.randn(4,3)
# y = torch.ones(4).long()
# print(an)
# print(y)
# l = loss(an,y)
# print(l)
# l.backward()
# print(an.grad)


裏面定義了triplet loss。這個東西說起來挺簡單,但實現時還是有些地方需要仔細琢磨。建議對待代碼不要只是看看就好,而是親手敲一遍,因爲一邊敲、一邊思考,才能真正學到其中的技巧。之前看triplet loss的時候,我總是在想:數據集按照PyTorch的格式組織的話,怎麼確定誰是anchor、誰是positive、誰是negative?這次通過親手敲代碼全部明白了:batch中的每個樣本都可以作爲anchor,只要把從dataloader中讀取的label轉化成一個相似度矩陣(label相同的樣本互爲positive,label不同的互爲negative),就可以知道誰是positive、誰是negative了。

model.py的代碼如下:

import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F

class resnet_model(nn.Module):
    """ResNet-50 backbone with an optional embedding, BN neck and classifier.

    Args:
        cut_at_pooling: when true, the model stops after global average
            pooling and no head layers are created.
        num_features: size of the optional linear embedding; 0 disables it
            and the backbone output width is used instead.
        norm: L2-normalize the BN feature during training.
        dropout: dropout probability applied to the BN feature; 0 disables.
        num_classes: number of classifier outputs; 0 disables the classifier.
    """

    def __init__(self,cut_at_pooling=False, num_features=0, norm=False, dropout=0, num_classes=0 ):
        super(resnet_model,self).__init__()
        self.cut_at_pooling = cut_at_pooling
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(torch.load('./pretrain_model/resnet50.pth'))
        # Set the stride of the last stage to 1.
        resnet.layer4[0].conv2.stride = (1,1)
        resnet.layer4[0].downsample[0].stride = (1,1)
        # NOTE(review): resnet.relu is not part of this stem — confirm this
        # is intentional before changing it, since it affects the features.
        self.base = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )
        self.gap = nn.AdaptiveAvgPool2d(1)

        if not self.cut_at_pooling:
            self.num_features = num_features
            self.norm = norm
            self.dropout = dropout
            self.has_embedding = num_features > 0
            self.num_classes = num_classes

            out_planes = resnet.fc.in_features

            if self.has_embedding:
                self.feat = nn.Linear(out_planes, self.num_features)
                self.feat_bn = nn.BatchNorm1d(self.num_features)
                nn.init.kaiming_normal_(self.feat.weight,mode='fan_out')
                nn.init.constant_(self.feat.bias,0)
            else:
                self.num_features = out_planes
                self.feat_bn = nn.BatchNorm1d(self.num_features)
            # The BN bias is frozen (and initialised to 0 below).
            self.feat_bn.bias.requires_grad_(False)
            if self.dropout > 0:
                self.drop = nn.Dropout(self.dropout)
            if self.num_classes > 0:
                self.classifier = nn.Linear(self.num_features,self.num_classes, bias=False)
                nn.init.normal_(self.classifier.weight, std=0.001)
            # feat_bn only exists in this branch, so it must be initialised
            # here; doing it unconditionally raised AttributeError when
            # cut_at_pooling=True.
            nn.init.constant_(self.feat_bn.weight, 1)
            nn.init.constant_(self.feat_bn.bias, 0)

    def forward(self,x,feature_withbn = False):
        """Run the network.

        Args:
            x: input batch passed to the backbone.
            feature_withbn: in training mode with a classifier, return the
                BN feature instead of the pooled feature.

        Returns:
            - cut_at_pooling: the pooled feature.
            - eval mode: the L2-normalized BN feature.
            - training, no classifier: (pooled feature, BN feature).
            - training, with classifier: (BN feature, logits) when
              ``feature_withbn`` is true, else (pooled feature, logits).
        """
        x = self.base(x)

        x = self.gap(x)
        x = x.view(x.size(0), -1)

        if self.cut_at_pooling:
            return x

        if self.has_embedding:
            bn_x = self.feat_bn(self.feat(x))
        else:
            bn_x = self.feat_bn(x)

        if not self.training:
            bn_x = F.normalize(bn_x)
            return bn_x

        if self.norm:
            bn_x = F.normalize(bn_x)
        elif self.has_embedding:
            bn_x = F.relu(bn_x)

        if self.dropout > 0:
            bn_x = self.drop(bn_x)

        if self.num_classes > 0:
            prob = self.classifier( bn_x )
        else:
            return x, bn_x

        if feature_withbn:
            return bn_x, prob
        return x, prob

model.py沒有太多需要說明的:以resnet50爲backbone,另外加入了一個BatchNorm層和一個線性分類器。

reid.py的代碼如下:

import torch
import torch.nn as nn
from torchvision  import datasets, transforms
from model import resnet_model
from torch.optim import lr_scheduler
import loss

# Training-time preprocessing: resize, random flip, pad + random crop back
# to 256x128, then tensor conversion and per-channel normalization.
transform_list = [
    transforms.Resize((256,128), interpolation=3),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Pad(10),
    transforms.RandomCrop((256,128)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
]

transform_compose = transforms.Compose(transform_list)

try_dataset1 = datasets.ImageFolder('./try_data1',transform_compose)
try_dataloader1 = torch.utils.data.DataLoader( try_dataset1, batch_size=32,shuffle=True )

try_data1_len = len(try_dataset1)
try_data1_class_name = try_dataset1.classes

# The classifier needs one output per identity (ImageFolder class), not one
# per image, so size it by the number of classes rather than the dataset length.
net = resnet_model(num_classes=len(try_data1_class_name))
net.cuda()

# One parameter group per trainable tensor, all with the same settings.
params = []
for key, value in net.named_parameters():
    if not value.requires_grad:
        continue
    params += [ { 'params': [value], 'lr': 0.00035, 'weight_decay':5e-4} ]

optimizer = torch.optim.Adam( params )
exp_lr_scheduler = lr_scheduler.StepLR( optimizer, step_size=10, gamma=0.1 )

criterion_ce = nn.CrossEntropyLoss()
criterion_triple = loss.TripletLoss()

# These histories are never cleared, so the values printed after each epoch
# are running averages over every batch seen so far.
triplet_loss_list = []
pre_loss_list = []
loss_list = []
acc_list = []

for epoch in range(30):

    print("epoch: {} / 30" .format(epoch + 1))
    for data in try_dataloader1:
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        features, pres = net(inputs)

        # Triplet loss works on the pooled features; cross-entropy must be
        # computed on the classifier logits, otherwise the classifier
        # receives no gradient.
        tri_loss, _ = criterion_triple( features, labels )
        ce_loss = criterion_ce(pres, labels)

        # Named total_loss so the imported `loss` module is not rebound.
        total_loss = tri_loss + ce_loss
        triplet_loss_list.append(tri_loss.item())
        pre_loss_list.append(ce_loss.item())
        loss_list.append(total_loss.item())

        _, pid = torch.max(pres.data, dim = 1)
        # Convert to float before dividing so the ratio is not truncated
        # by integer-tensor division on older PyTorch versions.
        acc = torch.sum( pid==labels.data ).float()/pid.size(0)
        acc_list.append(acc.item())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    exp_lr_scheduler.step()
    all_acc = sum(acc_list)/(len(acc_list))
    all_triple_loss = sum(triplet_loss_list)/(len(triplet_loss_list))
    all_pre_loss = sum(pre_loss_list)/(len(pre_loss_list))
    all_loss = sum(loss_list)/(len(loss_list))

    print('accuracy: {:.4f}'.format(all_acc))
    print('triplet loss: {:.4f}:'.format(all_triple_loss))
    print('predict loss: {:.4f}'.format(all_pre_loss))
    print('loss : {:.4f}'.format(all_loss))

權當記錄一下吧,有問題儘管留言。爭取做到全網最簡單易懂,但是最完整的代碼展示。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章