Focal loss的pytorch版本實現

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision     
import numpy as np
import matplotlib.pyplot as plt 
import torch.nn.functional as F
torch.manual_seed(1)   
EPOCH = 3          
BATCH_SIZE = 50
LR = 0.001          
DOWNLOAD_MNIST = True 
train_data = torchvision.datasets.MNIST(
    root='./mnist/',   
    train=True, 
    transform=torchvision.transforms.ToTensor(),                                                     
    download=DOWNLOAD_MNIST,          
)
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), requires_grad=True).type(torch.FloatTensor)[:2000]/255 
test_y = test_data.test_labels[:2000]
class Net(nn.Module):
    #定義Net的初始化函數,這個函數定義了該神經網絡的基本結構
    def __init__(self):
        super(Net, self).__init__() #複製並使用Net的父類的初始化方法,
                                    #即先運行nn.Module的初始化函數
        #self.conv1 = nn.Conv2d(1, 6, 5) # 定義conv1函數的是圖像卷積函數:
                                        #輸入爲圖像(1個頻道,即灰度圖),輸出爲 6張特徵圖, 卷積核爲5x5正方形
        #self.conv2 = nn.Conv2d(6, 16, 3)# 定義conv2函數的是圖像卷積函數:
                                        #輸入爲6張特徵圖,輸出爲16張特徵圖, 卷積核爲5x5正方形
        self.fc1   = nn.Linear(28*28, 120) # 定義fc1(fullconnect)全連接函數1爲線性函數:
                                            #y = Wx + b,並將16*5*5個節點連接到120個節點上。
        self.fc2   = nn.Linear(120, 84)#定義fc2(fullconnect)全連接函數2爲線性函數:
                                       #y = Wx + b,並將120個節點連接到84個節點上。
        self.fc3   = nn.Linear(84, 10)#定義fc3(fullconnect)全連接函數3爲線性函數:
                                      #y = Wx + b,並將84個節點連接到10個節點上。

    #定義該神經網絡的向前傳播函數,該函數必須定義,一旦定義成功,向後傳播函數也會自動生成(autograd)
    def forward(self, x):
        #x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) #輸入x經過卷積conv1之後,經過激活函數ReLU,
                                                        #使用2x2的窗口進行最大池化Max pooling,然後更新到x。
        #x = F.max_pool2d(F.relu(self.conv2(x)), 2) #輸入x經過卷積conv2之後,經過激活函數ReLU,
                                                   #使用2x2的窗口進行最大池化Max pooling,然後更新到x。
        x = x.view(-1, self.num_flat_features(x)) #view函數將張量x變形成一維的向量形式,
                                                  #總特徵數並不改變,爲接下來的全連接作準備。
        '''
        a=self.num_flat_features(x)
        x=x.view(-1,a) #view函數將張量x變形成一維的向量形式
        '''        
        x = F.relu(self.fc1(x)) #輸入x經過全連接1,再經過ReLU激活函數,然後更新x
        x = F.relu(self.fc2(x)) #輸入x經過全連接2,再經過ReLU激活函數,然後更新x
        x = self.fc3(x) #輸入x經過全連接3,然後更新x
        return x

    #使用num_flat_features函數計算張量x的總特徵量(把每個數字都看出是一個特徵,即特徵總量),
    #比如x是4*2*2的張量,那麼它的特徵總量就是16。
    def num_flat_features(self, x):
        size = x.size()[1:] # 這裏爲什麼要使用[1:],是因爲pytorch只接受批輸入,也就是說一次性輸入好幾張圖片,
                            #從第二維開始去,那麼輸入數據張量的維度自然上升到了4維。【1:】讓我們把注意力放在後3維上面
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
cnn = Net()
print(cnn)
# class FocalLoss(torch.nn.Module):
#     def __init__(self, gamma=2):
#         super().__init__()
#         self.gamma = gamma
# 
#     def forward(self, log_pred_prob_onehot, target):
#         pred_prob_oh = torch.exp(log_pred_prob_onehot)
#         pt = Variable(pred_prob_oh.data.gather(1, target.data.view(-1, 1)), requires_grad=True)
#         modulator = (1 - pt) ** self.gamma
#         mce = modulator * (-torch.log(pt))
# 
#         return mce.mean()
 
from torch.autograd import Variable
 
class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.
 
            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
 
        The losses are averaged across observations for each minibatch.
 
        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.
 
 
    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
 
    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs,dim=1)
 
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)
 
 
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]
 
        probs = (P*class_mask).sum(1).view(-1,1)
 
        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)
 
        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)
 
 
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

 
loss_func=FocalLoss(10)
i=0
j=0
k=0
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
#loss_func = nn.CrossEntropyLoss() 
loss_l= Variable(torch.rand(3600)).data.numpy()
true_l= Variable(torch.rand(3600)).data.numpy()
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):          
        b_x = Variable(x)  
        b_y = Variable(y)    
        output = cnn(b_x)  
        k=k+1
        loss= loss_func(output, b_y) 
        loss_l[i]=loss
        i=i+1
        optimizer.zero_grad()            
        loss.backward()                  
        optimizer.step()
test_output = cnn(test_x[:200])
print(test_output)
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
py=test_y[:200].numpy()
print(test_y[:200].numpy(), 'real number')
           
ss=len(loss_l)
kk=Variable(torch.linspace(1,ss,ss)).data.numpy()
plt.plot(kk,loss_l)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章