pytorch model.apply lambda setattr

我在看SLIMMABLE NEURAL NETWORKS模型代碼的時候,被他的這句代碼給繞懵了

model.apply(lambda m: setattr(m, 'width_mult', width_mult))

他同時包含了apply、lambda、setattr幾個我不懂的點,所以我費了很多時間才搞懂了這句話的意思。他其實等價於下面這段代碼

def fn(m):
    setattr(m, 'width_mult', width_mult) 
model.apply(fn)

1. lambda

簡單來說,lambda是一個匿名函數,它的一個最簡單的例子是:

lambda x:x+1(1)
#等價於
def g(x):
    return x+1

所以

lambda m: setattr(m, 'width_mult', width_mult)
#等價於
def fn(m):
    setattr(m, 'width_mult', width_mult) 

2. setattr

setattr(object, name, value) 作用是給對象object的屬性name賦值value

3. apply

model.apply方法會逐個遍歷model的子模塊

 

def fn(m):
    setattr(m, 'width_mult', width_mult) 
model.apply(fn)

所以綜合起來,上面的代碼的意思就是,逐個遍歷model的子模塊,給子模塊中的變量width_mult賦值。(部分子模塊有這個變量,部分沒有)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章