CV之街景字符编码识别三----pytorch之定义网络,损失函数和优化器

一、定义网络(Net)

用一个名为Net的类定义
需要继承torch.nn中的nn.Module(注意M大写)
Net类包括初始化函数和forward函数两部分

1)初始化
init_(self): 放置有可学习参数的层(注意init前后均是两个下划线)
a)对nn.Module初始化: super(Net, self)init()
b)定义卷积和全连接操作(用到nn.Conv2d(), nn.Linear())
2)前向操作
forward(self, x)
输入x,按照网络前向传播步骤,调用初始化中定义的卷积和全连接操作,得到最后输出,并return。

如下简单定义一个cnn模型:

class SVHN_Model1(nn.Module):
#初始化
    def __init__(self):
        super(SVHN_Model1,self).__init__()
        ##CNN提取模块
        self.cnn=nn.Sequential(
            nn.Conv2d(3,16,kernel_size=(3,3),stride=(2,2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16,32,kernel_size=(3,3),stride=(2,2)),
            nn.ReLU()
            nn.MaxPool2d(2),
            
        )
        self.fc1 = nn.Linear(32*3*7,11)
        self.fc2 = nn.Linear(32*3*7,11)
        self.fc3 = nn.Linear(32*3*7,11)
        self.fc4 = nn.Linear(32*3*7,11)
        self.fc5 = nn.Linear(32*3*7,11)
        self.fc6 = nn.Linear(32*3*7,11)
       #前向传播
    def forword(self,img):
        feat=self.cnn(img)
        feat=feat.view(feat.shape[0],-1)
        c1=self.fc1(feat)
        c2=self.fc2(feat)
        c3=self.fc3(feat)
        c4=self.fc4(feat)
        c5=self.fc5(feat)
        c6=self.fc6(feat)
        return c1,c2,c3,c4,c5,c6
    model=SVHN_Model1()
    

二、定义损失函数和优化器

损失函数评估结果与label间的差距,通过backward损失函数,可以计算出每个参数的梯度,然后通过
优化器调整参数
损失函数nn中已定义好
import torch.optim as optim
优化器在optim中定义好,调用即可。

optim:
在这里插入图片描述

在这里插入图片描述

#损失函数
criterion = nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model,parameters(),0.005)
loss plot,c0_plot=[],[]
#迭代10个Epoch
for epoch in range (10):
    for data in train_loader:
        c0,c1,c3,c3,c4,c5=model(data[0])
        loss=criterion(c0,data[1][:,0])+\
             criterion(c1,data[1][:,1])+\
             criterion(c2,data[1][:,2])+\
             criterion(c3,data[1][:,3])+\
             criterion(c4,data[1][:,4])+\
             criterion(c5,data[1][:,5])
        loss /=6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step
        
        loss_plot.append(loss.item())
        c0_plot.append((c0.argmax(1)==data[1][:,0]).sum().item()*1.0/c0.shape[0])
print(epoch)
             
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章