純 Python 和 PyTorch 對比實現循環神經網絡 RNN 及反向傳播

摘要

本文分別使用純 Python（NumPy）和 PyTorch 實現循環神經網絡 RNN 及其反向傳播，並對比兩者的隱藏狀態與各參數梯度，驗證手寫實現的正確性。

相關

原理和詳細解釋，請參考：

循環神經網絡RNNCell單元詳解及反向傳播的梯度求導

https://blog.csdn.net/oBrightLamp/article/details/85015325

正文

import torch
import numpy as np


class RNNCell:
    """Single-layer tanh RNN cell with a hand-written backward pass (NumPy).

    Each forward call computes

        h_t = tanh(x_t . W_ih^T + h_{t-1} . W_hh^T + b_ih + b_hh)

    and pushes what backward needs onto LIFO stacks.

    ``backward`` must therefore be called once per forward step, in reverse
    time order (most recent step first).

    Gradients are not reduced here. Each backward call appends one entry to
    ``dw_ih_stack``, ``dw_hh_stack``, ``db_ih_stack`` and ``db_hh_stack``,
    and the caller sums them. ``dx_list`` is kept in forward time order.
    None of these lists is ever cleared by the class.
    """

    def __init__(self, weight_ih, weight_hh,
                 bias_ih, bias_hh):
        # Parameters are kept by reference (not copied).
        self.weight_ih, self.weight_hh = weight_ih, weight_hh
        self.bias_ih, self.bias_hh = bias_ih, bias_hh

        # Forward caches; backward() pops them in reverse order.
        self.x_stack = []
        self.prev_hidden_stack = []
        self.next_hidden_stack = []

        # Per-step gradients produced by backward().
        self.dx_list = []
        self.dw_ih_stack = []
        self.dw_hh_stack = []
        self.db_ih_stack = []
        self.db_hh_stack = []

        # Gradient carried from step t back into h_{t-1}. Once backward() has
        # run for every cached step, it holds the gradient w.r.t. the initial
        # hidden state.
        self.prev_dh = None

    def __call__(self, x, prev_hidden):
        """Run one forward time step and return the new hidden state.

        x: this step's input. Its last axis is contracted with
            ``weight_ih``'s columns. Assumed 2-D (batch, features), because
            backward forms ``d_tanh.T . x``.
        prev_hidden: hidden state produced by the previous step (or the
            initial state).
        """
        self.x_stack.append(x)

        from_input = np.dot(x, self.weight_ih.T)
        from_hidden = np.dot(prev_hidden, self.weight_hh.T)
        pre_activation = from_input + from_hidden + self.bias_ih + self.bias_hh
        hidden = np.tanh(pre_activation)

        self.prev_hidden_stack.append(prev_hidden)
        self.next_hidden_stack.append(hidden)
        # A fresh forward step has no gradient arriving from a later step yet.
        self.prev_dh = np.zeros(hidden.shape)
        return hidden

    def backward(self, dh):
        """Back-propagate through the most recent un-processed forward step.

        dh: gradient of the loss w.r.t. this step's output hidden state. It
            excludes the part flowing back from later steps, which the cell
            carries internally in ``prev_dh``.

        Returns ``self.dx_list`` itself (the live list, not a copy).
        """
        x = self.x_stack.pop()
        h_prev = self.prev_hidden_stack.pop()
        h_next = self.next_hidden_stack.pop()

        # Total gradient at this step's output, pushed through tanh:
        # d tanh(z)/dz = 1 - tanh(z)**2.
        grad_z = (dh + self.prev_dh) * (1 - h_next ** 2)
        # Hand the h_{t-1} gradient to the next (earlier) backward call.
        self.prev_dh = np.dot(grad_z, self.weight_hh)

        # Steps are visited newest-first, so prepend to keep forward order.
        self.dx_list.insert(0, np.dot(grad_z, self.weight_ih))
        self.dw_ih_stack.append(np.dot(grad_z.T, x))
        self.dw_hh_stack.append(np.dot(grad_z.T, h_prev))
        # Both biases enter the pre-activation additively, so each receives
        # grad_z unreduced (one row per sample); the caller sums over rows.
        for bias_grads in (self.db_ih_stack, self.db_hh_stack):
            bias_grads.append(grad_z)

        return self.dx_list


if __name__ == '__main__':
    # Demo: run the NumPy RNNCell and torch.nn.RNN on identical weights,
    # inputs, initial hidden state and upstream gradients. Print both results
    # side by side, then report whether each pair matches numerically.
    np.random.seed(123)
    torch.random.manual_seed(123)
    np.set_printoptions(precision=6, suppress=True)

    # input_size=4, hidden_size=5; float64 so results are comparable to NumPy.
    rnn_PyTorch = torch.nn.RNN(4, 5).double()
    # Layer 0's parameters, in the order RNNCell takes them:
    # weight_ih, weight_hh, bias_ih, bias_hh.
    w_ih, w_hh, b_ih, b_hh = rnn_PyTorch.all_weights[0]
    rnn_numpy = RNNCell(w_ih.detach().numpy(),
                        w_hh.detach().numpy(),
                        b_ih.detach().numpy(),
                        b_hh.detach().numpy())

    nums = 3  # sequence length
    # (seq_len, batch=3, input_size=4); nn.RNN defaults to batch_first=False.
    x3_numpy = np.random.random((nums, 3, 4))
    x3_tensor = torch.tensor(x3_numpy, requires_grad=True)

    # Initial hidden state: (num_layers=1, batch=3, hidden_size=5).
    h3_numpy = np.random.random((1, 3, 5))
    h0_tensor = torch.tensor(h3_numpy, requires_grad=True)

    # Upstream gradient dL/d(output) for every time step. It is a constant
    # passed to backward(), so it does not need requires_grad.
    dh_numpy = np.random.random((nums, 3, 5))
    dh_tensor = torch.tensor(dh_numpy)

    # nn.RNN returns (output, h_n). Binding the result to new names keeps
    # h0_tensor pointing at the leaf tensor, so its gradient can be checked.
    out_tensor, _ = rnn_PyTorch(x3_tensor, h0_tensor)

    h_numpy = h3_numpy[0]
    h_numpy_list = []
    for x_t in x3_numpy:
        h_numpy = rnn_numpy(x_t, h_numpy)
        h_numpy_list.append(h_numpy)

    out_tensor.backward(dh_tensor)
    # RNNCell.backward pops its forward caches, so feed the last step first.
    for dh_t in reversed(dh_numpy):
        rnn_numpy.backward(dh_t)

    # Each entry is (numpy label, numpy value, torch label, torch value).
    # The NumPy cell stores one gradient per time step, so the stacks are
    # summed here. The bias stacks also keep one row per batch sample and
    # are summed over that axis too.
    # After the last backward() call, rnn_numpy.prev_dh is the gradient
    # w.r.t. the initial hidden state.
    comparisons = [
        ("numpy_hidden", np.array(h_numpy_list),
         "tensor_hidden", out_tensor.detach().numpy()),
        ("dx_numpy", np.array(rnn_numpy.dx_list),
         "dx_tensor", x3_tensor.grad.numpy()),
        ("dw_ih_numpy", np.sum(rnn_numpy.dw_ih_stack, axis=0),
         "dw_ih_tensor", w_ih.grad.numpy()),
        ("dw_hh_numpy", np.sum(rnn_numpy.dw_hh_stack, axis=0),
         "dw_hh_tensor", w_hh.grad.numpy()),
        ("db_ih_numpy", np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)),
         "db_ih_tensor", b_ih.grad.numpy()),
        ("db_hh_numpy", np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)),
         "db_hh_tensor", b_hh.grad.numpy()),
        ("dh0_numpy", rnn_numpy.prev_dh,
         "dh0_tensor", h0_tensor.grad[0].numpy()),
    ]

    for i, (np_label, np_value, pt_label, pt_value) in enumerate(comparisons):
        if i:
            print("------")
        print("{} :\n".format(np_label), np_value)
        print("{} :\n".format(pt_label), pt_value)

    print("------")
    for np_label, np_value, pt_label, pt_value in comparisons:
        print("{} == {} : {}".format(np_label, pt_label,
                                     np.allclose(np_value, pt_value)))

    # Output recorded from the earlier version of this script. It predates
    # the dh0 comparison and the allclose summary lines printed above.
    """
    代碼輸出
    numpy_hidden :
     [[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391]
      [ 0.365172 -0.361254  0.426838 -0.448951  0.331553]
      [ 0.589187 -0.188248  0.684941 -0.45859   0.190099]]
    
     [[ 0.146213 -0.306517  0.297109  0.370957 -0.040084]
      [-0.009201 -0.365735  0.333659  0.486789  0.061897]
      [ 0.030064 -0.282985  0.42643   0.025871  0.026388]]
    
     [[ 0.225432 -0.015057  0.116555  0.080901  0.260097]
      [ 0.368327  0.258664  0.357446  0.177961  0.55928 ]
      [ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
    tensor_hidden :
     [[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391]
      [ 0.365172 -0.361254  0.426838 -0.448951  0.331553]
      [ 0.589187 -0.188248  0.684941 -0.45859   0.190099]]
    
     [[ 0.146213 -0.306517  0.297109  0.370957 -0.040084]
      [-0.009201 -0.365735  0.333659  0.486789  0.061897]
      [ 0.030064 -0.282985  0.42643   0.025871  0.026388]]
    
     [[ 0.225432 -0.015057  0.116555  0.080901  0.260097]
      [ 0.368327  0.258664  0.357446  0.177961  0.55928 ]
      [ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
    ------
    dx_numpy :
     [[[-0.643965  0.215931 -0.476378  0.072387]
      [-1.221727  0.221325 -0.757251  0.092991]
      [-0.59872  -0.065826 -0.390795  0.037424]]
    
     [[-0.537631 -0.303022 -0.364839  0.214627]
      [-0.815198  0.392338 -0.564135  0.217464]
      [-0.931365 -0.254144 -0.561227  0.164795]]
    
     [[-1.055966  0.249554 -0.623127  0.009784]
      [-0.45858   0.108994 -0.240168  0.117779]
      [-0.957469  0.315386 -0.616814  0.205634]]]
    dx_tensor :
     [[[-0.643965  0.215931 -0.476378  0.072387]
      [-1.221727  0.221325 -0.757251  0.092991]
      [-0.59872  -0.065826 -0.390795  0.037424]]
    
     [[-0.537631 -0.303022 -0.364839  0.214627]
      [-0.815198  0.392338 -0.564135  0.217464]
      [-0.931365 -0.254144 -0.561227  0.164795]]
    
     [[-1.055966  0.249554 -0.623127  0.009784]
      [-0.45858   0.108994 -0.240168  0.117779]
      [-0.957469  0.315386 -0.616814  0.205634]]]
    ------
    dw_ih_numpy :
     [[ 3.918335  2.958509  3.725173  4.157478]
     [ 1.261197  0.812825  1.10621   0.97753 ]
     [ 2.216469  1.718251  2.366936  2.324907]
     [ 3.85458   3.052212  3.643157  3.845696]
     [ 1.806807  1.50062   1.615917  1.521762]]
    dw_ih_tensor :
     [[ 3.918335  2.958509  3.725173  4.157478]
     [ 1.261197  0.812825  1.10621   0.97753 ]
     [ 2.216469  1.718251  2.366936  2.324907]
     [ 3.85458   3.052212  3.643157  3.845696]
     [ 1.806807  1.50062   1.615917  1.521762]]
    ------
    dw_hh_numpy :
     [[ 2.450078  0.243735  4.269672  0.577224  1.46911 ]
     [ 0.421015  0.372353  0.994656  0.962406  0.518992]
     [ 1.079054  0.042843  2.12169   0.863083  0.757618]
     [ 2.225794  0.188735  3.682347  0.934932  0.955984]
     [ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
    dw_hh_tensor :
     [[ 2.450078  0.243735  4.269672  0.577224  1.46911 ]
     [ 0.421015  0.372353  0.994656  0.962406  0.518992]
     [ 1.079054  0.042843  2.12169   0.863083  0.757618]
     [ 2.225794  0.188735  3.682347  0.934932  0.955984]
     [ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
    ------
    db_ih_numpy :
     [ 7.568411  2.175445  4.335336  6.820628  3.51003 ]
    db_ih_tensor :
     [ 7.568411  2.175445  4.335336  6.820628  3.51003 ]
    ------
    db_hh_numpy :
     [ 7.568411  2.175445  4.335336  6.820628  3.51003 ]
    db_hh_tensor :
     [ 7.568411  2.175445  4.335336  6.820628  3.51003 ]
    """
發表評論
所有評論
還沒有人評論，想成為第一個評論的人嗎？請在上方評論欄輸入並點擊發布。
相關文章