詳情可參考官方文檔
所有優化器都實現一種step()更新參數的方法。它可以以兩種方式使用:
optimizer.step()
這是大多數優化程序支持的簡化版本。一旦用來計算梯度,就可以調用該函數 backward()。
例:
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.step(closure)
一些優化算法(例如共軛梯度和LBFGS)需要多次重新評估函數,因此您必須傳遞一個閉包以允許它們重新計算模型。閉合應清除梯度,計算損耗,然後將其返回。
例:
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)