Comparing Python and PyTorch implementations of the dropout function and its backward pass

Related

For the underlying principles and a detailed derivation, see:
BrightLamp. Dropout explained, with the gradient derivation for backpropagation [EB/OL]. https://blog.csdn.net/oBrightLamp/article/details/84105097.
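
In brief, both implementations below use inverted dropout: during training each element is kept with probability 1 - p and the surviving elements are rescaled by 1 / (1 - p), so the expected activation is unchanged and the test-time pass is a plain identity. The forward and backward passes are

    out = x * mask / (1 - p),    mask ~ Bernoulli(1 - p)
    dx  = d_loss * mask / (1 - p)

where p is the dropout ratio and the backward pass reuses the mask sampled in the forward pass.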

Main text

import torch
import numpy as np


class Dropout:
    """
    Inverted dropout (Hinton et al., 2012):
    http://arxiv.org/abs/1207.0580
    """

    def __init__(self, dropout_ratio=0.5):
        self.dropout_ratio = dropout_ratio  # probability of zeroing an element
        self.mask = None  # boolean mask saved in the forward pass for backward

    def __call__(self, x, manual_mask=None, train_flg=True):
        if train_flg:
            # Keep each element with probability (1 - dropout_ratio); a manual
            # mask can be passed in to reproduce PyTorch's sampling exactly.
            if manual_mask is None:
                self.mask = np.random.rand(*x.shape) > self.dropout_ratio
            else:
                self.mask = manual_mask
            # Inverted dropout: rescale so the expected activation is unchanged.
            out = x * self.mask / (1.0 - self.dropout_ratio)
            return out
        # At test time dropout is the identity.
        return x

    def backward(self, d_loss):
        # The gradient flows through the same mask and scaling as the forward pass.
        dx = d_loss * self.mask / (1.0 - self.dropout_ratio)
        return dx

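# A quick illustrative sanity check (not part of the comparison below): the
# 1 / (1 - p) rescaling keeps the expected output equal to the input, so for
# an all-ones input the sample mean should stay close to 1.
_check = Dropout(dropout_ratio=0.45)(np.ones((1000, 1000)))
assert abs(_check.mean() - 1.0) < 0.01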

np.set_printoptions(precision=6, suppress=True, linewidth=120)
np.random.seed(12)
torch.random.manual_seed(3)

x_numpy = np.random.random((3, 7))  # strictly positive values, so the mask can be recovered from the output
x_tensor = torch.tensor(x_numpy, requires_grad=True)

drop_out_numpy = Dropout(dropout_ratio=0.45)
drop_out_tensor = torch.nn.Dropout(p=0.45)  # p is likewise the probability of zeroing an element

print("\n----- Training phase -----")
train_flag = True
drop_out_tensor.train()

out_tensor = drop_out_tensor(x_tensor)
# PyTorch does not expose its dropout mask, so recover it from the output:
# an entry is nonzero exactly when it survived (all inputs are positive).
mask = out_tensor > 0
mask = mask.data.numpy()
out_numpy = drop_out_numpy(x_numpy, mask, train_flg=train_flag)

print("train mask : \n", mask)
print("train x : \n", x_numpy)
print("numpy out : \n", out_numpy)
print("tensor out : \n", out_tensor.data.numpy())

print("\n----- Backward pass -----")
d_loss_numpy = np.random.random((3, 7))  # a random upstream gradient shared by both implementations
d_loss_tensor = torch.tensor(d_loss_numpy)  # requires_grad is not needed on the gradient argument

dx_numpy = drop_out_numpy.backward(d_loss_numpy)
out_tensor.backward(d_loss_tensor)
dx_tensor = x_tensor.grad
print("dx_numpy : \n", dx_numpy)
print("dx_tensor : \n", dx_tensor.data.numpy())
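
# The two gradients should agree exactly: both apply the same mask and the
# same 1 / (1 - p) scaling (an illustrative check, not in the captured output).
assert np.allclose(dx_numpy, dx_tensor.data.numpy())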

print("\n----- Test phase -----")
train_flag = False
drop_out_tensor.eval()

out_tensor = drop_out_tensor(x_tensor)
# In eval mode dropout is the identity, so every (positive) entry survives
# and the recovered mask is all ones.
mask = out_tensor > 0
mask = mask.data.numpy()
out_numpy = drop_out_numpy(x_numpy, mask, train_flg=train_flag)

print("test mask : \n", mask)
print("test x : \n", x_numpy)
print("numpy out : \n", out_numpy)
print("tensor out : \n", out_tensor.data.numpy())
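
# In eval mode both implementations are the identity, so the outputs should
# match the input exactly (illustrative check, not in the captured output).
assert np.allclose(out_numpy, x_numpy)
assert np.allclose(out_tensor.data.numpy(), x_numpy)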

"""
Code output:

----- Training phase -----
train mask : 
 [[1 0 1 0 1 0 0]
 [0 0 1 1 1 0 1]
 [1 1 0 0 0 1 0]]
train x : 
 [[ 0.154163  0.74005   0.263315  0.533739  0.014575  0.918747  0.900715]
 [ 0.033421  0.956949  0.137209  0.283828  0.606083  0.944225  0.852736]
 [ 0.002259  0.521226  0.552038  0.485377  0.768134  0.160717  0.76456 ]]
numpy out : 
 [[ 0.280296  0.        0.478755  0.        0.0265    0.        0.      ]
 [ 0.        0.        0.249471  0.516052  1.101969  0.        1.550428]
 [ 0.004108  0.947684  0.        0.        0.        0.292212  0.      ]]
tensor out : 
 [[ 0.280296  0.        0.478755  0.        0.0265    0.        0.      ]
 [ 0.        0.        0.249471  0.516052  1.101969  0.        1.550428]
 [ 0.004108  0.947684  0.        0.        0.        0.292212  0.      ]]

----- Backward pass -----
dx_numpy : 
 [[ 0.037836  0.        0.211405  0.        1.220823  0.        0.      ]
 [ 0.        0.        1.277495  0.595581  0.60845   0.        1.135604]
 [ 1.727843  1.39541   0.        0.        0.        0.728421  0.      ]]
dx_tensor : 
 [[ 0.037836  0.        0.211405  0.        1.220823  0.        0.      ]
 [ 0.        0.        1.277495  0.595581  0.60845   0.        1.135604]
 [ 1.727843  1.39541   0.        0.        0.        0.728421  0.      ]]

----- Test phase -----
test mask : 
 [[1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1]]
test x : 
 [[ 0.154163  0.74005   0.263315  0.533739  0.014575  0.918747  0.900715]
 [ 0.033421  0.956949  0.137209  0.283828  0.606083  0.944225  0.852736]
 [ 0.002259  0.521226  0.552038  0.485377  0.768134  0.160717  0.76456 ]]
numpy out : 
 [[ 0.154163  0.74005   0.263315  0.533739  0.014575  0.918747  0.900715]
 [ 0.033421  0.956949  0.137209  0.283828  0.606083  0.944225  0.852736]
 [ 0.002259  0.521226  0.552038  0.485377  0.768134  0.160717  0.76456 ]]
tensor out : 
 [[ 0.154163  0.74005   0.263315  0.533739  0.014575  0.918747  0.900715]
 [ 0.033421  0.956949  0.137209  0.283828  0.606083  0.944225  0.852736]
 [ 0.002259  0.521226  0.552038  0.485377  0.768134  0.160717  0.76456 ]]
"""