參考中文官方,詳情參考:PyTorch 如何自定義 Module
1.自定義Module
Module 是 pytorch 組織神經網絡的基本方式。Module 包含了模型的參數以及計算邏輯。Function 承載了實際的功能,定義了前向和後向的計算邏輯。
下面以最簡單的 MLP 網絡結構爲例,介紹下如何實現自定義網絡結構。完整代碼可以參見repo。
1.1 Function
Function 是 pytorch 自動求導機制的核心類。Function 是無參數或者說無狀態的,它只負責接收輸入,返回相應的輸出;對於反向,它接收輸出相應的梯度,返回輸入相應的梯度。
這裏我們只關注如何自定義 Function。Function 的定義見源碼。下面是簡化的代碼段:
class Function(object):
def forward(self, *input):
raise NotImplementedError
def backward(self, *grad_output):
raise NotImplementedError
forward 和 backward 的輸入和輸出都是 Tensor 對象。
Function 對象是 callable 的,即可以通過()的方式進行調用。其中調用的輸入和輸出都爲 Variable 對象。下面的代碼示例瞭如何實現一個 ReLU 激活函數並進行調用:
import torch
from torch.autograd import Function
class ReLUF(Function):
def forward(self, input):
self.save_for_backward(input)
output = input.clamp(min=0)
return output
def backward(self, output_grad):
input = self.to_save[0]
input_grad = output_grad.clone()
input_grad[input < 0] = 0
return input_grad
## Test
if __name__ == "__main__":
from torch.autograd import Variable
torch.manual_seed(1111)
a = torch.randn(2, 3)
va = Variable(a, requires_grad=True)
vb = ReLUF()(va)
print va.data, vb.data
vb.backward(torch.ones(va.size()))
print vb.grad.data, va.grad.data
如果 backward 中需要用到 forward 的輸入,需要在 forward 中顯式的保存需要的輸入。在上面的代碼中,forward 利用self.save_for_backward函數,將輸入暫時保存,並在 backward 中利用saved_tensors (python tuple 對象) 取出。
顯然,forward 的輸入應該和 backward 的輸入相對應;同時,forward 的輸出應該和 backward 的輸入相匹配。
注意:
由於 Function 可能需要暫存 input tensor,因此,建議不復用 Function 對象,以避免遇到內存提前釋放的問題。
如示例代碼所示,forward的每次調用都重新生成一個 ReLUF 對象,而不能在初始化時生成在 forward 中反覆調用。(意思示例代碼是對的,其是每個fc都生成一個新對象。因爲每個對象都自己在管理自己的的權重跟輸入)
2.2 Module
類似於 Function,Module 對象也是 callable 是,輸入和輸出也是 Variable。不同的是,Module 是[可以]有參數的。Module 包含兩個主要部分:參數及計算邏輯(Function 調用)。由於ReLU激活函數沒有參數,這裏我們以最基本的全連接層爲例來說明如何自定義Module。
全連接層的運算邏輯定義如下 Function:
import torch
from torch.autograd import Function
class LinearF(Function):
def forward(self, input, weight, bias=None):
self.save_for_backward(input, weight, bias)
output = torch.mm(input, weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
def backward(self, grad_output):
input, weight, bias = self.saved_tensors
grad_input = grad_weight = grad_bias = None
if self.needs_input_grad[0]:
grad_input = torch.mm(grad_output, weight)
if self.needs_input_grad[1]:
grad_weight = torch.mm(grad_output.t(), input)
if bias is not None and self.needs_input_grad[2]:
grad_bias = grad_output.sum(0).squeeze(0)
if bias is not None:
return grad_input, grad_weight, grad_bias
else:
return grad_input, grad_weight
needs_input_grad 爲一個元素爲 bool 型的 tuple,長度與 forward 的參數數量相同,用來標識各個輸入是否輸入計算梯度;對於無需梯度的輸入,可以減少不必要的計算。
Function(此處爲 LinearF) 定義了基本的計算邏輯,Module 只需要在初始化時爲參數分配內存空間,並在計算時,將參數傳遞給相應的 Function 對象。代碼如下:
import torch
import torch.nn as nn
class Linear(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
def forward(self, input):
return LinearF()(input, self.weight, self.bias)
需要注意的是,參數是內存空間由 tensor 對象維護,但 tensor 需要包裝爲一個Parameter 對象。Parameter 是 Variable 的特殊子類,僅有是不同是 Parameter 默認requires_grad爲 True。Varaible 是自動求導機制的核心類,此處暫不介紹,參見教程。
2.3自定義循環神經網絡(RNN)
可運行代碼參考:
RNN
其中Parameters是Variable的一個子類,而且其是自動求導機制,所以我們在定義代碼網絡的時候基本不需要從在backwards,只要定義好forward就可以了
2.3 定義DoReFaNet
由於其壓縮的是權重,所以我們只需要定義一個卷積類,激活函數類,然後去獲取他裏面的權重,進行量化,定義網絡如下:
class AlexNet_Q(nn.Module):
def __init__(self, wbit, abit, num_classes=1000):
super(AlexNet_Q, self).__init__()
Conv2d = conv2d_Q_fn(w_bit=wbit)
Linear = linear_Q_fn(w_bit=wbit)
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
nn.BatchNorm2d(96),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
Conv2d(96, 256, kernel_size=5, padding=2),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
activation_quantize_fn(a_bit=abit),
nn.MaxPool2d(kernel_size=3, stride=2),
Conv2d(256, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
activation_quantize_fn(a_bit=abit),
Conv2d(384, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
activation_quantize_fn(a_bit=abit),
Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
activation_quantize_fn(a_bit=abit),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.classifier = nn.Sequential(
Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
activation_quantize_fn(a_bit=abit),
Linear(4096, 4096),
nn.ReLU(inplace=True),
activation_quantize_fn(a_bit=abit),
nn.Linear(4096, num_classes),
)
for m in self.modules():
if isinstance(m, Conv2d) or isinstance(m, Linear):
init.xavier_normal_(m.weight.data)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return
其中自定義權重量化如下:
class weight_quantize_fn(nn.Module):
def __init__(self, w_bit):
super(weight_quantize_fn, self).__init__()
assert w_bit <= 8 or w_bit == 32
self.w_bit = w_bit
self.uniform_q = uniform_quantize(k=w_bit)
def forward(self, x):
if self.w_bit == 32:
weight_q = x
elif self.w_bit == 1:
E = torch.mean(torch.abs(x)).detach()
weight_q = self.uniform_q(x / E) * E
else:
weight = torch.tanh(x)
weight = weight / 2 / torch.max(torch.abs(weight)) + 0.5
weight_q = 2 * self.uniform_q(weight) - 1
return weight_q
def conv2d_Q_fn(w_bit):
class Conv2d_Q(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2d_Q, self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
self.w_bit = w_bit
self.quantize_fn = weight_quantize_fn(w_bit=w_bit)
def forward(self, input, order=None):
weight_q = self.quantize_fn(self.weight)
# print(np.unique(weight_q.detach().numpy()))
return F.conv2d(input, weight_q, self.bias, self.stride,
self.padding, self.dilation, self.groups)
return Conv2d_Q