博主在學習三值神經網絡時,使用了LeNet-5模型,編程代碼,需要對LeNet-5模型中的卷積層與全連接層進行自定義,搜索他人方法後,博主產生了一個疑問,絕大多數提供的自定義層方法都是繼承 nn.Module 模型,而這方法據說是官方提供(官網網址:PyTorch),自定義線性層代碼如下:
class Linear(nn.Module):
def __init__(self, input_features, output_features, bias=True):
super(Linear, self).__init__()
self.input_features = input_features
self.output_features = output_features
self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(output_features))
else:
self.register_parameter('bias', None)
self.weight.data.uniform_(-0.1, 0.1)
if bias is not None:
self.bias.data.uniform_(-0.1, 0.1)
def forward(self, input):
return LinearFunction.apply(input, self.weight, self.bias)
博主因涉獵PyTorch時長過短,對此瞭解不多,初始僅是覺得奇怪,爲何自定義層需要繼承nn.Module模型,乍一看像是自定義模型而非自定義層
因博主過於小白,當時就想,這麼誤人子弟的自定義方法不應該籠統地用nn.Module一齊實現,PyTorch是一個成熟的軟件,應該知道還有博主這種小小白沒辦法立刻上手以nn.Module爲依賴的自定義層方法(說實話,博主看了幾天纔看懂上面的代碼是啥意思……哭泣……)
所以博主又搜索了GitHub中與已發表論文相關的PyTorch深度學習代碼,終於發現了PyTorch自定義層的簡單方法,可以細節到區分具體層的種類,而不會產生混淆
以下爲博主使用的LeNet-5模型,爲直觀體現自定義層的方法,博主僅留下了自定義的卷積層與全連接層,即代碼中 TernaryConv2d(1, 32, kernel_size=5) 與 TernaryLinear(512, 10) 部分
# LeNet-5模型
class LeNet_5(nn.Module):
def __init__(self):
super(LeNet_5, self).__init__()
self.conv1 = TernaryConv2d(1, 32, kernel_size=5) # 卷積
self.fc1 = TernaryLinear(512, 10) # 全連接
def forward(self, x):
x = self.conv1(x)
x = self.fc1(x)
return x
以下爲博主自定義的卷積層(TernaryConv2d)與全連接層(TernaryLinear)
class TernaryConv2d(nn.Conv2d):
# def __init__(self, *args, **kwargs):
# 該方法接受任意個數的參數,其中不指定key的參數會以list形式保存到args變量中,指定key的參數會以dict的形式保存到kwargs變量中
def __init__(self, *args, **kwargs):
super(TernaryConv2d, self).__init__(*args, **kwargs)
def forward(self, input):
out = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return out
class TernaryLinear(nn.Linear):
def __init__(self, *args, **kwargs):
super(TernaryLinear, self).__init__(*args, **kwargs)
def forward(self, input):
out = F.linear(input, self.weight, self.bias)
return out
以卷積層爲例,類名可自行定義,但必須繼承 nn.Conv2d ,在該層初始化使用時,若無參數設置,將會默認卷積層的參數,默認參數可查閱官方文檔(在 Python API 中)獲知,這裏博主將其粘貼了過來:
CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
而後,在 def init() 函數中可初始化參數,在 def forward() 函數中可對卷積層自行設計
若想自定義其他類型的層,只需修改繼承的 nn.Conv2d 即可,操作非常簡單且不會產生歧義
博主將模型與層的代碼整合如下,便於複製使用:
class TernaryConv2d(nn.Conv2d):
# def __init__(self, *args, **kwargs):
# 該方法接受任意個數的參數,其中不指定key的參數會以list形式保存到args變量中,指定key的參數會以dict的形式保存到kwargs變量中
def __init__(self, *args, **kwargs):
super(TernaryConv2d, self).__init__(*args, **kwargs)
def forward(self, input):
out = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return out
class TernaryLinear(nn.Linear):
def __init__(self, *args, **kwargs):
super(TernaryLinear, self).__init__(*args, **kwargs)
def forward(self, input):
out = F.linear(input, self.weight, self.bias)
return out
# LeNet-5模型
class LeNet_5(nn.Module):
def __init__(self):
super(LeNet_5, self).__init__()
self.conv1 = TernaryConv2d(1, 32, kernel_size=5) # 卷積
self.fc1 = TernaryLinear(512, 10) # 全連接
def forward(self, x):
x = self.conv1(x)
x = self.fc1(x)
return x
因官方模型類API過多,博主無法一一查看,博主猜測nn.Module模型中應該與自定義層所使用的nn.Conv2d有着某種繼承關係,所以可以如此使用