pytorch使用自定義網絡

1、定義網絡結構

import torch.nn as nn

import torch.nn.functional as F

class Model(nn.Module):

     def __init__(self):

          super(Model,self).__init__()

          self.conv1=nn.Conv2d(1,20,5)

          self.conv2=nn.Conv2d(20,20,5)

    def forward(self,x):

          x=F.relu(self.conv1(x))

          return F.relu(self.conv2(x))

這個例子定義了一個只有兩層的網絡Model,其中兩個函數:

--初始化函數 __init__(self):定義了具體網絡有什麼層,這裏實際上沒有決定網絡的結構,也就是說將上面的例子中的self.conv1和self.conv2定義的前後順序調換是完全沒有影響的。

--forward函數定義了網絡的前向傳播的順序

2、網絡參數初始化

pytorch官方提供了多種初始化函數

torch.nn.init.uniform(tensor,a=0,b=1)

torch.nn.init.normal(tensor,mean=0,std=1)

torch.nn.init.constant(tensor,val)

torch.nn.init.xavier_uniform(tensor,gain=1)

初始化函數可以直接作用於神經網絡參數

1、對網絡的某一層參數進行初始化

import torch.nn as nn

import torch.nn.init as init

conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3)

init.xavier _uniform(conv1.weight)

init.constant(conv1.bas,0.1)

2、對整個網絡的參數進行初始化

def weights_init(m):

    if isinstance(m,nn.Conv2d):

        xavier(m.weight.data)

        xavier(m.bias.data)

下面舉一個例子,定義一個網絡MyNet,網絡由6層的卷積構成:

import torch
import torch.nn as nn
import torch.nn.init as init

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv1=nn.ReLU(inplace=True)
        self.relu1=nn.ReLU(inplace=True)
        self.conv2=nn.Conv2d(64,128,7,padding=3)
        self.relu2=nn.ReLU(inplace=True)
        self.conv3=nn.Conv2d(128,256,5,padding=2)
        self.relu3=nn.ReLU(inplace=True)
        self.conv4=nn.Conv2d(256,128,5,padding=2)
        self.relu4=nn.ReLU(inplace=True)
        self.conv5=nn.conv2d(128,64,3,padding=1)
        self.relu5=nn.ReLU(inplace=True)
        self.conv6=nn.conv2d(64,6,3,padding=1)
        self.relu6=nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                init.xavier_uniform(m.weight.data)
                init.constant(m.bias.data,0.1)

        def forward(self,x):
            x=self.relu1(self.conv1(x))
            x=self.relu2(self.conv2(x))
            x=self.relu3(self.conv3(x))
            x=self.relu4(self.conv4(x))
            x=self.relu5(self.conv5(x))
            f=self.relu6(self.conv6(x))
            return f

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章