CV之街景字符編碼識別三----pytorch之定義網絡,損失函數和優化器

一、定義網絡(Net)

用一個名爲Net的類定義
需要繼承torch.nn中的nn.Module(注意M大寫)
Net類包括初始化函數和forward函數兩部分

1)初始化
init_(self): 放置有可學習參數的層(注意init前後均是兩個下劃線)
a)對nn.Module初始化: super(Net, self)init()
b)定義卷積和全連接操作(用到nn.Conv2d(), nn.Linear())
2)前向操作
forward(self, x)
輸入x,按照網絡前向傳播步驟,調用初始化中定義的卷積和全連接操作,得到最後輸出,並return。

如下簡單定義一個cnn模型:

class SVHN_Model1(nn.Module):
#初始化
    def __init__(self):
        super(SVHN_Model1,self).__init__()
        ##CNN提取模塊
        self.cnn=nn.Sequential(
            nn.Conv2d(3,16,kernel_size=(3,3),stride=(2,2)),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16,32,kernel_size=(3,3),stride=(2,2)),
            nn.ReLU()
            nn.MaxPool2d(2),
            
        )
        self.fc1 = nn.Linear(32*3*7,11)
        self.fc2 = nn.Linear(32*3*7,11)
        self.fc3 = nn.Linear(32*3*7,11)
        self.fc4 = nn.Linear(32*3*7,11)
        self.fc5 = nn.Linear(32*3*7,11)
        self.fc6 = nn.Linear(32*3*7,11)
       #前向傳播
    def forword(self,img):
        feat=self.cnn(img)
        feat=feat.view(feat.shape[0],-1)
        c1=self.fc1(feat)
        c2=self.fc2(feat)
        c3=self.fc3(feat)
        c4=self.fc4(feat)
        c5=self.fc5(feat)
        c6=self.fc6(feat)
        return c1,c2,c3,c4,c5,c6
    model=SVHN_Model1()
    

二、定義損失函數和優化器

損失函數評估結果與label間的差距,通過backward損失函數,可以計算出每個參數的梯度,然後通過
優化器調整參數
損失函數nn中已定義好
import torch.optim as optim
優化器在optim中定義好,調用即可。

optim:
在這裏插入圖片描述

在這裏插入圖片描述

#損失函數
criterion = nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model,parameters(),0.005)
loss plot,c0_plot=[],[]
#迭代10個Epoch
for epoch in range (10):
    for data in train_loader:
        c0,c1,c3,c3,c4,c5=model(data[0])
        loss=criterion(c0,data[1][:,0])+\
             criterion(c1,data[1][:,1])+\
             criterion(c2,data[1][:,2])+\
             criterion(c3,data[1][:,3])+\
             criterion(c4,data[1][:,4])+\
             criterion(c5,data[1][:,5])
        loss /=6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step
        
        loss_plot.append(loss.item())
        c0_plot.append((c0.argmax(1)==data[1][:,0]).sum().item()*1.0/c0.shape[0])
print(epoch)
             
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章