Comparing Python and PyTorch Implementations of the MaxPool Pooling Layer and Its Backward Pass

Related

For the underlying principles and a detailed explanation, see:
Detailed Explanation of the MaxPool Pooling Function and Derivation of Its Backward-Pass Formulas
https://blog.csdn.net/oBrightLamp/article/details/84635346
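
Briefly: writing $W_{ij}$ for the pooling window that produces output cell $(i, j)$, the forward pass keeps only the window maximum, and the backward pass routes each upstream gradient entirely to the position that won the max (a recap of the standard result; the linked post gives the full derivation):

$$y_{ij} = \max_{(p,q)\in W_{ij}} x_{pq}, \qquad \frac{\partial L}{\partial x_{pq}} = \sum_{i,j} \frac{\partial L}{\partial y_{ij}}\,\mathbf{1}\!\left[(p,q) = \operatorname*{arg\,max}_{(p',q')\in W_{ij}} x_{p'q'}\right]$$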

Main text

import torch
import numpy as np


class MaxPool2D:
    def __init__(self, kernel_size=(2, 2), stride=2):
        self.stride = stride

        self.kernel_size = kernel_size
        self.w_height = kernel_size[0]
        self.w_width = kernel_size[1]

        self.x = None  # input cached by the forward pass for use in backward

        self.in_height = None
        self.in_width = None

        self.out_height = None
        self.out_width = None

        self.arg_max = None  # flat argmax of each window, filled in __call__

    def __call__(self, x):
        self.x = x
        self.in_height = np.shape(x)[0]
        self.in_width = np.shape(x)[1]

        # Standard output-size formula for pooling without padding.
        self.out_height = (self.in_height - self.w_height) // self.stride + 1
        self.out_width = (self.in_width - self.w_width) // self.stride + 1

        out = np.zeros((self.out_height, self.out_width))
        # Flat index of the maximum inside each window, saved for backward.
        self.arg_max = np.zeros_like(out, dtype=np.int32)
        for i in range(self.out_height):
            for j in range(self.out_width):
                start_i = i * self.stride
                start_j = j * self.stride
                end_i = start_i + self.w_height
                end_j = start_j + self.w_width
                window = x[start_i: end_i, start_j: end_j]
                out[i, j] = np.max(window)
                self.arg_max[i, j] = np.argmax(window)

        return out

    def backward(self, d_loss):
        dx = np.zeros_like(self.x)
        for i in range(self.out_height):
            for j in range(self.out_width):
                start_i = i * self.stride
                start_j = j * self.stride
                end_i = start_i + self.w_height
                end_j = start_j + self.w_width
                # Recover the 2D position of the max from the stored flat index.
                index = np.unravel_index(self.arg_max[i, j], self.kernel_size)
                # Route the upstream gradient to the argmax position; accumulate
                # so overlapping windows (stride < kernel) sum their gradients,
                # matching PyTorch's behavior.
                dx[start_i: end_i, start_j: end_j][index] += d_loss[i, j]
        return dx
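
The backward pass stores np.argmax's flat (row-major) index and uses np.unravel_index to convert it back into a 2D offset inside the window. A tiny standalone illustration of that step (values chosen arbitrarily):

import numpy as np

window = np.array([[0.1, 0.9],
                   [0.3, 0.5]])
flat = np.argmax(window)  # 1: row-major flat index of the maximum (0.9)
row, col = np.unravel_index(flat, window.shape)
print(row, col)  # 0 1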


np.set_printoptions(precision=8, suppress=True, linewidth=120)
np.random.seed(123)

# One NCHW sample; the NumPy version above works on a single 2D feature map.
x_numpy = np.random.random((1, 1, 6, 8))
x_tensor = torch.tensor(x_numpy, requires_grad=True)

max_pool_tensor = torch.nn.MaxPool2d((2, 2), 2)
max_pool_numpy = MaxPool2D((2, 2), stride=2)

out_numpy = max_pool_numpy(x_numpy[0, 0])
out_tensor = max_pool_tensor(x_tensor)

# Random upstream gradient, fed identically to both implementations.
d_loss_numpy = np.random.random(out_tensor.shape)
d_loss_tensor = torch.tensor(d_loss_numpy)
out_tensor.backward(d_loss_tensor)

dx_numpy = max_pool_numpy.backward(d_loss_numpy[0, 0])
dx_tensor = x_tensor.grad

print("out_numpy \n", out_numpy)
print("out_tensor \n", out_tensor.data.numpy())

print("dx_numpy \n", dx_numpy)
print("dx_tensor \n", dx_tensor.data.numpy())

"""
Code output:
out_numpy 
 [[ 0.69646919  0.72904971  0.71946897  0.9807642 ]
 [ 0.72244338  0.53182759  0.84943179  0.72445532]
 [ 0.62395295  0.42583029  0.89338916  0.98555979]]
out_tensor 
 [[[[ 0.69646919  0.72904971  0.71946897  0.9807642 ]
   [ 0.72244338  0.53182759  0.84943179  0.72445532]
   [ 0.62395295  0.42583029  0.89338916  0.98555979]]]]
dx_numpy 
 [[ 0.51948512  0.          0.          0.          0.12062867  0.          0.8263408   0.        ]
 [ 0.          0.          0.          0.61289453  0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.54506801  0.          0.34276383  0.30412079  0.        ]
 [ 0.60306013  0.          0.          0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.68130077  0.          0.          0.87545684  0.          0.        ]
 [ 0.41702221  0.          0.          0.          0.          0.          0.          0.51042234]]
dx_tensor 
 [[[[ 0.51948512  0.          0.          0.          0.12062867  0.          0.8263408   0.        ]
   [ 0.          0.          0.          0.61289453  0.          0.          0.          0.        ]
   [ 0.          0.          0.          0.54506801  0.          0.34276383  0.30412079  0.        ]
   [ 0.60306013  0.          0.          0.          0.          0.          0.          0.        ]
   [ 0.          0.          0.68130077  0.          0.          0.87545684  0.          0.        ]
   [ 0.41702221  0.          0.          0.          0.          0.          0.          0.51042234]]]]
"""