【NAS工具箱】Pytorch中的Buffer&Parameter

Parameter : 模型中的一種可以被反向傳播更新的參數。

第一種:

  • 直接通過成員變量nn.Parameter()進行創建,會自動註冊到parameter中。
def __init__(self):
    super(MyModel, self).__init__()
    self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成員變量

或者:

  • 通過nn.Parameter() 創建普通對象
  • 通過register_parameter()進行註冊
  • 可以通過model.parameters()返回
def __init__(self):
    super(MyModel, self).__init__()
    param = nn.Parameter(torch.randn(3, 3))  # 普通 Parameter 對象
    self.register_parameter("param", param)

Buffer : 模型中不能被反向傳播算法更新的參數。

  • 創建tensor
  • 將tensor通過register_buffer進行註冊
  • 可以通過model.buffers()返回
def __init__(self):
    super(MyModel, self).__init__()
    buffer = torch.randn(2, 3)  # tensor
    self.register_buffer('my_buffer', buffer)
    self.param = nn.Parameter(torch.randn(3, 3))  # 模型的成員變量

總結:

  • 模型參數=parameter+buffer; optimizer只能更新parameter,不能更新buffer,buffer只能通過forward進行更新。
  • 模型保存的參數 model.state_dict() 返回一個OrderDict
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章