Pytorch 學習(五):Pytorch 實現多層感知機(MLP)

Pytorch 實現多層感知機(MLP)

本方法總結自《動手學深度學習》(Pytorch版)github項目

實現多層感知器(Multlayer Perceptron)同樣遵循以下步驟:

  • 數據集讀取
  • 模型搭建和參數初始化
  • 損失函數和下降器構建
  • 模型訓練

方法一:從零開始實現

import torch
import torch.nn as nn
import numpy as np
import d2lzh_pytorch as d2l

# 各層節點數
num_i = 28 * 28
num_h = 256
num_o = 10

# 構建數據
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 參數初始化
w1 = torch.tensor(np.random.normal(0, 0.01, (num_i, num_h)), dtype=torch.float32, requires_grad=True)
b1 = torch.zeros(num_h, requires_grad=True)
w2 = torch.tensor(np.random.normal(0, 0.01, (num_h, num_o)), dtype=torch.float32, requires_grad=True)
b2 = torch.zeros(num_o, requires_grad=True)
params = [w1, b1, w2, b2]

# 激活函數
def relu(x):
    return torch.max(x, torch.tensor(0.0))

# 模型構建
def net(x):
    x = x.view(-1, num_i)
    h = relu(x.mm(w1) + b1)
    o = h.mm(w2) + b2
    return o

# 損失函數
loss = nn.CrossEntropyLoss()

# 訓練模型
num_epochs = 5
lr = 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

方法二:能調包就不實現

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import d2lzh_pytorch as d2l

# node number of MLP Layer
num_i, num_h, num_o = 28 * 28, 256, 10

# data load
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# network build
class MLP(nn.Module):
    def __init__(self, n_i, n_h, n_o):
        super(MLP, self).__init__()
        self.flatten = d2l.FlattenLayer()
        self.linear1 = nn.Linear(n_i, n_h)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(n_h, n_o)

    def forward(self, input):
        return self.linear2(self.relu(self.linear1(self.flatten(input))))

net = MLP(num_i, num_h, num_o)
for param in net.parameters():
    init.normal_(param, mean=0, std=0.01)

# loss
loss = nn.CrossEntropyLoss()

# optimizer
optimizer = optim.SGD(net.parameters(), lr=0.5)

# train
num_epochs = 5
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, optimizer=optimizer)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章