【Pytorch】optimizer原理

optimizer原理

【參考筆記】
【源碼鏈接
舉個栗子,定義一個全連接網絡:

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(2, 2)
# 權重矩陣初始化爲1
nn.init.constant_(net.weight, val=100)
nn.init.constant_(net.bias, val=20)
optimizer = optim.SGD(net.parameters(), lr=0.01)

1. 測試optimizer有哪些屬性

print(optimizer.__dict__)

得到:

{'defaults': {'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, 'state': defaultdict(<class 'dict'>, {}), 'param_groups': [{'params': [Parameter containing:
tensor([[100., 100.],
        [100., 100.]], requires_grad=True), Parameter containing:
tensor([20., 20.], requires_grad=True)], 'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]}

2. 測試optimizer的param_groups包含哪些參數

print(optimizer.param_groups)

得到:

[{'params': [Parameter containing:
tensor([[100., 100.],
        [100., 100.]], requires_grad=True), Parameter containing:
tensor([20., 20.], requires_grad=True)], 'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]

其中2x2的矩陣是net的權重矩陣,1x2爲偏置矩陣,其餘爲優化器的其它參數,所以說param_groups保存了優化器的全部數據,這個下面的state_dict()不同。

3. optimizer的狀態 state_dict()

參考下面源碼中對state_dict()的定義,可以看出state_dict()包含優化器狀態state和參數param_groups兩個參數

def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict` """
    # Save ids instead of Tensors
    def pack_group(group):
        # 對"params"和其它的鍵採用不同規則
        packed = {k: v for k, v in group.items() if k != 'params'}
        # 這裏並沒有保存參數的值,而是保存參數的id
        packed['params'] = [id(p) for p in group['params']]
        return packed
    # 對self.param_groups進行遍歷
    param_groups = [pack_group(g) for g in self.param_groups]
    # Remap state to use ids as keys
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    # 返回狀態和參數組,其中參數組纔是優化器的參數
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }

打印優化器參數:

print(optimizer.state_dict()["param_groups"])

可以到優化器的完整參數如下:

[{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 
'nesterov': False, 'params': [2149749904224, 2149749906312]}]

打印優化器完整狀態(狀態+參數):

print(optimizer.state_dict())

可以到優化器的狀態如下:

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2423968124216, 2423968124360]}]}

保存優化器的完整狀態:

optimizer_old = optim.SGD(net.parameters(), lr=100) 
torch.save(optimizer_old.state_dict(), "optim_old.npy")

4. optimizer的load_state_dict()

恢復優化器的完整狀態:

optimizer_new = optim.SGD(net.parameters(), lr=0.01)
old_state = torch.load("optim_old.npy")
# 將之前定義的優化器參數給新的優化器
optimizer_new.load_state_dict(old_state)
print(optimizer_new.state_dict()["param_groups"])

5. optimizer的梯度清空zero_grad()

optimizer.zero_grad()源碼定義如下:

def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    # 獲取每一組參數
    for group in self.param_groups:
        # 遍歷當前參數組所有的params
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

這個遍歷過程就是獲取optimizer的param_groups屬性的字典,之中的[“params”],之中的所有參數,通過遍歷設定每個參數的梯度值爲0。

6. optimizer的單步更新step()

直接看源碼:

def step(self, closure):
    r"""Performs a single optimization step (parameter update).
    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.
    """
    raise NotImplementedError

優化器的step()函數負責更新參數值,但是其具體實現對於不同的優化算法是不同的,所以optimizer類只是定義了這種行爲,但是並沒有給出具體實現。

【其他參考資料】

發佈了162 篇原創文章 · 獲贊 67 · 訪問量 16萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章