Related
For the underlying theory and a detailed derivation, see:
MaxPool layer explained in detail, with the backpropagation formula derived (池化層MaxPool函數詳解及反向傳播的公式推導)
https://blog.csdn.net/oBrightLamp/article/details/84635346
Main text
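Briefly, the rule the code below implements (my restatement of the linked derivation, writing the window as k_h × k_w and the stride as s): each output element is the maximum over its window, and the upstream gradient flows back only to the position that attained that maximum:

y_{ij} = \max_{0 \le u < k_h,\ 0 \le v < k_w} x_{is+u,\, js+v},
\qquad
\frac{\partial L}{\partial x_{pq}} =
\begin{cases}
\frac{\partial L}{\partial y_{ij}} & \text{if } (p,q) \text{ is the arg max of window } (i,j) \\
0 & \text{otherwise}
\end{cases}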
import torch
import numpy as np


class MaxPool2D:
    def __init__(self, kernel_size=(2, 2), stride=2):
        self.stride = stride
        self.kernel_size = kernel_size
        self.w_height = kernel_size[0]
        self.w_width = kernel_size[1]

        self.x = None

        self.in_height = None
        self.in_width = None

        self.out_height = None
        self.out_width = None

        # flat (row-major) index of the max inside each pooling window,
        # cached in the forward pass and reused in the backward pass
        self.arg_max = None

    def __call__(self, x):
        self.x = x
        self.in_height = np.shape(x)[0]
        self.in_width = np.shape(x)[1]

        # standard output-size formula for a valid (unpadded) pooling layer
        self.out_height = int((self.in_height - self.w_height) / self.stride) + 1
        self.out_width = int((self.in_width - self.w_width) / self.stride) + 1

        out = np.zeros((self.out_height, self.out_width))
        self.arg_max = np.zeros_like(out, dtype=np.int32)

        for i in range(self.out_height):
            for j in range(self.out_width):
                start_i = i * self.stride
                start_j = j * self.stride
                end_i = start_i + self.w_height
                end_j = start_j + self.w_width
                window = x[start_i: end_i, start_j: end_j]
                out[i, j] = np.max(window)
                self.arg_max[i, j] = np.argmax(window)

        return out

    def backward(self, d_loss):
        # the input gradient is zero everywhere except at the positions
        # that won the max in the forward pass
        dx = np.zeros_like(self.x)
        for i in range(self.out_height):
            for j in range(self.out_width):
                start_i = i * self.stride
                start_j = j * self.stride
                end_i = start_i + self.w_height
                end_j = start_j + self.w_width
                # convert the cached flat index back to 2-D window coordinates
                index = np.unravel_index(self.arg_max[i, j], self.kernel_size)
                dx[start_i: end_i, start_j: end_j][index] = d_loss[i, j]
        return dx
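A quick illustration of the np.unravel_index step above (a throwaway sketch, not part of the implementation; it reuses the np import from the listing): np.argmax flattens the window row-major, and np.unravel_index converts that flat index back into (row, col) coordinates inside the window:

w = np.array([[0.1, 0.4],
              [0.2, 0.9]])
flat = np.argmax(w)               # 3, because the window is flattened row-major
np.unravel_index(flat, w.shape)   # (1, 1): row 1, column 1 of the window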
np.set_printoptions(precision=8, suppress=True, linewidth=120)
np.random.seed(123)

# a single-sample, single-channel 6x8 input; the NumPy class works on the
# 2-D slice, while PyTorch expects the full NCHW tensor
x_numpy = np.random.random((1, 1, 6, 8))
x_tensor = torch.tensor(x_numpy, requires_grad=True)

max_pool_tensor = torch.nn.MaxPool2d((2, 2), 2)
max_pool_numpy = MaxPool2D((2, 2), stride=2)

# forward pass
out_numpy = max_pool_numpy(x_numpy[0, 0])
out_tensor = max_pool_tensor(x_tensor)

# backward pass with a random upstream gradient
d_loss_numpy = np.random.random(out_tensor.shape)
d_loss_tensor = torch.tensor(d_loss_numpy, requires_grad=True)
out_tensor.backward(d_loss_tensor)

dx_numpy = max_pool_numpy.backward(d_loss_numpy[0, 0])
dx_tensor = x_tensor.grad

print("out_numpy \n", out_numpy)
print("out_tensor \n", out_tensor.data.numpy())
print("dx_numpy \n", dx_numpy)
print("dx_tensor \n", dx_tensor.data.numpy())
"""
Code output:
out_numpy
[[ 0.69646919 0.72904971 0.71946897 0.9807642 ]
[ 0.72244338 0.53182759 0.84943179 0.72445532]
[ 0.62395295 0.42583029 0.89338916 0.98555979]]
out_tensor
[[[[ 0.69646919 0.72904971 0.71946897 0.9807642 ]
[ 0.72244338 0.53182759 0.84943179 0.72445532]
[ 0.62395295 0.42583029 0.89338916 0.98555979]]]]
dx_numpy
[[ 0.51948512 0. 0. 0. 0.12062867 0. 0.8263408 0. ]
[ 0. 0. 0. 0.61289453 0. 0. 0. 0. ]
[ 0. 0. 0. 0.54506801 0. 0.34276383 0.30412079 0. ]
[ 0.60306013 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0.68130077 0. 0. 0.87545684 0. 0. ]
[ 0.41702221 0. 0. 0. 0. 0. 0. 0.51042234]]
dx_tensor
[[[[ 0.51948512 0. 0. 0. 0.12062867 0. 0.8263408 0. ]
[ 0. 0. 0. 0.61289453 0. 0. 0. 0. ]
[ 0. 0. 0. 0.54506801 0. 0.34276383 0.30412079 0. ]
[ 0.60306013 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0.68130077 0. 0. 0.87545684 0. 0. ]
[ 0.41702221 0. 0. 0. 0. 0. 0. 0.51042234]]]]
"""