pytorch 入門筆記 | 談談pytorch的框架特色

pytorch 可以說是深度學習入門首選的框架,語法特點特別接近numpy,上手簡單。作爲一門流行的框架,總有它流行的原因,筆者認爲這是pytorch框架的一些特色所決定的,以下內容來源筆者在入門學習中的體會,因此作文總結。

近期我簡單入門了一下深度學習,對 pytorch 有了一定的掌握和認識,不得不感慨 pytorch 大法好,對深度學習新手特別友好,和numpy有着相似的語法特點,但有一些專門爲深度學習設計的框架特色,以下結合個人所學和體會,得出以下兩大特色:

  1. 聲明張量矩陣,可以自動計算梯度,省去了許多計算代碼;
  2. 快速搭建神經網絡,提供簡單易懂的模型算法接口

 

自動計算梯度

在 pytorch 中矩陣變量是以張量(tensor)來聲明定義的,我們可以在聲明張量的時候,決定該變量是否自動計算梯度。

以下給出簡單的例子。

簡單的線性模型 - 1

import torch
import torchvision 
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms

x = torch.tensor(1., requires_grad = True)
w = torch.tensor(2., requires_grad = True)
b = torch.tensor(3., requires_grad = True)
print(x, w, b)
y = w * x + b
print(y)

print(x.grad)
print(w.grad)
print(b.grad)

# compute gradient
y.backward()

print(x.grad)
print(w.grad)
print(b.grad)

輸出結果如下:

簡單線性模型 - 2

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt


# Hyper-parameters
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001

# Toy dataset
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], 
                    [9.779], [6.182], [7.59], [2.167], [7.042], 
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], 
                    [3.366], [2.596], [2.53], [1.221], [2.827], 
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

# Linear regression model
model = nn.Linear(input_size, output_size)

# Define the loss function
criterion = nn.MSELoss()
# Define the optimiter to solve
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  

# Train the model
for epoch in range(num_epochs):
    # Convert numpy arrays to torch tensors
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)

    # Forward pass
    outputs = model(inputs)
    
    # calculate loss fucntion value.
    loss = criterion(outputs, targets)
    
    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    
    # update the parameter
    optimizer.step()
    
    if (epoch+1) % 5 == 0:
        print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# Plot the graph
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

 

快速搭建神經網絡

pytorch 提供的nn模塊可以幫助我們很快定義好一個網絡結構,主要有幾個步驟:

  1. 定義模型,損失函數,優化求解方法
  2. 開始訓練,輸出預測結果,計算損失函數值,反向更新參數
  3. 直到迭代結束,模型訓練成功。

以下給出一個簡單的三層神經網絡結構。

import torch
import matplotlib.pyplot as plt

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algoriths. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learn_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr = learn_rate)

t_list = []
loss_list = []

for t in range(500):
    # output the predicted value
    y_pred = model(x)
    
    # calculate loss function value
    loss = loss_fn(y_pred, y)
    
    # for visualization
    t_list.append(t)
    loss_list.append(loss)
    
    # make the optimizer's gradient to be zero
    optimizer.zero_grad()
    
    # calculate gradient and update parameter
    loss.backward()
    optimizer.step()
    
plt.plot(t_list, loss_list, label = 'loss')
plt.show()

隨迭代次數的增加,損失函數值逐漸變小。

 

pytorch 的魅力遠不止如此,本菜以後更熟練的時候在總結!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章