Python小練習:優化器torch.optim的使用

Python小練習:優化器torch.optim的使用

作者:凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/

本文主要介紹Pytorch中優化器的使用方法,瞭解optimizer.zero_grad()、loss.backward()以及optimizer.step()函數的用法。

問題陳述:假設最小化目標函數爲$L = \sum\nolimits_{i = 1}^N {x_i^2} $。給定初始值$\left[ {x_1^{(0)},x_2^{(0)}} \right] = [5,{\rm{ }}10]$,求最優解$\hat x = \arg \min L$。

1. optim_test.py

 1 # -*- coding: utf-8 -*-
 2 # Author:凱魯嘎吉 Coral Gajic
 3 # https://www.cnblogs.com/kailugaji/
 4 # Python小練習:優化器torch.optim的使用
 5 # 假設損失函數爲loss = ∑(x^2),給定初始x(0),目的是找到最優的x,使得loss最小化
 6 # 以2維數據爲例,loss = x1^2 + x2^2
 7 # loss對x1的偏導爲2 * x1,loss對x2的偏導爲2 * x2
 8 '''
 9 部分參考:
10     https://www.cnblogs.com/zhouyang209117/p/16048331.html
11 '''
12 import torch
13 import torch.optim as optim
14 import numpy as np
15 import matplotlib.pyplot as plt
16 plt.rc('font',family='Times New Roman')
17 
18 # 方法一:
19 # 自己實現的優化
20 rate = 0.1 # 學習率
21 iteration = 30 # 迭代次數
22 num = 3
23 data = np.array([5.0, 10.0]) # 給定初始數據x(0)
24 for i in range(iteration):
25     loss = (data ** 2).sum(axis = 0) # 優化目標:最小化loss = x1^2 + x2^2
26     my_grad = 2 * data # 梯度d(loss)/dx:2 * x1, 2 * x2
27     print('%d.' %(i+1),
28           '數據:', np.around(data, num),
29           '\t損失函數:', np.around(loss, num),
30           '\t梯度:', np.around(my_grad, num)
31           )
32     data -= my_grad * rate # 優化更新:x = x - learning_rate * (d(loss)/dx)
33 
34 print('----------------------------------------------------------------')
35 # 方法二:
36 # pytorch自帶的優化器
37 data = np.array([5.0, 10.0]) # 給定初始數據x(0)
38 data = torch.tensor(data, requires_grad=True) # 需要求梯度
39 optimizer = optim.SGD([data], lr = rate) # 優化器:隨機梯度下降
40 plot_loss = []
41 for i in range(iteration):
42     optimizer.zero_grad() # 清空先前的梯度
43     loss = (data ** 2).sum() # 優化目標:最小化loss = x1^2 + x2^2
44     print('%d.' %(i+1),
45           '數據:', np.around(data.detach().numpy(), num),
46           '\t損失函數:', np.around(loss.item(), num),
47           end=' '
48           )
49     loss.backward() # 計算當前梯度d(loss)/dx = 2 * x,並反向傳播
50     print('\t梯度:', np.around(data.grad.detach().numpy(), num)) # 打印梯度
51     optimizer.step() # 優化更新:x = x - learning_rate * (d(loss)/dx)
52     plot_loss.append([i+1, loss.item()]) # 保存每次迭代的loss
53 
54 plot_loss = np.array(plot_loss) # 將list轉換成numpy
55 # --------------------畫Loss曲線圖------------------------
56 plt.plot(plot_loss[:, 0], plot_loss[:, 1], color = 'red', ls = '-')
57 plt.xlabel('Iteration')
58 plt.ylabel('Loss')
59 plt.tight_layout()
60 plt.savefig('Loss.png', bbox_inches='tight', dpi=500)
61 plt.show()

2. 結果

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/optimizer/optim_test.py"
1. 數據: [ 5. 10.]     損失函數: 125.0     梯度: [10. 20.]
2. 數據: [4. 8.]     損失函數: 80.0     梯度: [ 8. 16.]
3. 數據: [3.2 6.4]     損失函數: 51.2     梯度: [ 6.4 12.8]
4. 數據: [2.56 5.12]     損失函數: 32.768     梯度: [ 5.12 10.24]
5. 數據: [2.048 4.096]     損失函數: 20.972     梯度: [4.096 8.192]
6. 數據: [1.638 3.277]     損失函數: 13.422     梯度: [3.277 6.554]
7. 數據: [1.311 2.621]     損失函數: 8.59     梯度: [2.621 5.243]
8. 數據: [1.049 2.097]     損失函數: 5.498     梯度: [2.097 4.194]
9. 數據: [0.839 1.678]     損失函數: 3.518     梯度: [1.678 3.355]
10. 數據: [0.671 1.342]     損失函數: 2.252     梯度: [1.342 2.684]
11. 數據: [0.537 1.074]     損失函數: 1.441     梯度: [1.074 2.147]
12. 數據: [0.429 0.859]     損失函數: 0.922     梯度: [0.859 1.718]
13. 數據: [0.344 0.687]     損失函數: 0.59     梯度: [0.687 1.374]
14. 數據: [0.275 0.55 ]     損失函數: 0.378     梯度: [0.55 1.1 ]
15. 數據: [0.22 0.44]     損失函數: 0.242     梯度: [0.44 0.88]
16. 數據: [0.176 0.352]     損失函數: 0.155     梯度: [0.352 0.704]
17. 數據: [0.141 0.281]     損失函數: 0.099     梯度: [0.281 0.563]
18. 數據: [0.113 0.225]     損失函數: 0.063     梯度: [0.225 0.45 ]
19. 數據: [0.09 0.18]     損失函數: 0.041     梯度: [0.18 0.36]
20. 數據: [0.072 0.144]     損失函數: 0.026     梯度: [0.144 0.288]
21. 數據: [0.058 0.115]     損失函數: 0.017     梯度: [0.115 0.231]
22. 數據: [0.046 0.092]     損失函數: 0.011     梯度: [0.092 0.184]
23. 數據: [0.037 0.074]     損失函數: 0.007     梯度: [0.074 0.148]
24. 數據: [0.03  0.059]     損失函數: 0.004     梯度: [0.059 0.118]
25. 數據: [0.024 0.047]     損失函數: 0.003     梯度: [0.047 0.094]
26. 數據: [0.019 0.038]     損失函數: 0.002     梯度: [0.038 0.076]
27. 數據: [0.015 0.03 ]     損失函數: 0.001     梯度: [0.03 0.06]
28. 數據: [0.012 0.024]     損失函數: 0.001     梯度: [0.024 0.048]
29. 數據: [0.01  0.019]     損失函數: 0.0     梯度: [0.019 0.039]
30. 數據: [0.008 0.015]     損失函數: 0.0     梯度: [0.015 0.031]
----------------------------------------------------------------
1. 數據: [ 5. 10.]     損失函數: 125.0     梯度: [10. 20.]
2. 數據: [4. 8.]     損失函數: 80.0     梯度: [ 8. 16.]
3. 數據: [3.2 6.4]     損失函數: 51.2     梯度: [ 6.4 12.8]
4. 數據: [2.56 5.12]     損失函數: 32.768     梯度: [ 5.12 10.24]
5. 數據: [2.048 4.096]     損失函數: 20.972     梯度: [4.096 8.192]
6. 數據: [1.638 3.277]     損失函數: 13.422     梯度: [3.277 6.554]
7. 數據: [1.311 2.621]     損失函數: 8.59     梯度: [2.621 5.243]
8. 數據: [1.049 2.097]     損失函數: 5.498     梯度: [2.097 4.194]
9. 數據: [0.839 1.678]     損失函數: 3.518     梯度: [1.678 3.355]
10. 數據: [0.671 1.342]     損失函數: 2.252     梯度: [1.342 2.684]
11. 數據: [0.537 1.074]     損失函數: 1.441     梯度: [1.074 2.147]
12. 數據: [0.429 0.859]     損失函數: 0.922     梯度: [0.859 1.718]
13. 數據: [0.344 0.687]     損失函數: 0.59     梯度: [0.687 1.374]
14. 數據: [0.275 0.55 ]     損失函數: 0.378     梯度: [0.55 1.1 ]
15. 數據: [0.22 0.44]     損失函數: 0.242     梯度: [0.44 0.88]
16. 數據: [0.176 0.352]     損失函數: 0.155     梯度: [0.352 0.704]
17. 數據: [0.141 0.281]     損失函數: 0.099     梯度: [0.281 0.563]
18. 數據: [0.113 0.225]     損失函數: 0.063     梯度: [0.225 0.45 ]
19. 數據: [0.09 0.18]     損失函數: 0.041     梯度: [0.18 0.36]
20. 數據: [0.072 0.144]     損失函數: 0.026     梯度: [0.144 0.288]
21. 數據: [0.058 0.115]     損失函數: 0.017     梯度: [0.115 0.231]
22. 數據: [0.046 0.092]     損失函數: 0.011     梯度: [0.092 0.184]
23. 數據: [0.037 0.074]     損失函數: 0.007     梯度: [0.074 0.148]
24. 數據: [0.03  0.059]     損失函數: 0.004     梯度: [0.059 0.118]
25. 數據: [0.024 0.047]     損失函數: 0.003     梯度: [0.047 0.094]
26. 數據: [0.019 0.038]     損失函數: 0.002     梯度: [0.038 0.076]
27. 數據: [0.015 0.03 ]     損失函數: 0.001     梯度: [0.03 0.06]
28. 數據: [0.012 0.024]     損失函數: 0.001     梯度: [0.024 0.048]
29. 數據: [0.01  0.019]     損失函數: 0.0     梯度: [0.019 0.039]
30. 數據: [0.008 0.015]     損失函數: 0.0     梯度: [0.015 0.031]

Process finished with exit code 0

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章