pytorch model.apply lambda setattr

我在看SLIMMABLE NEURAL NETWORKS模型代码的时候,被他的这句代码给绕懵了

model.apply(lambda m: setattr(m, 'width_mult', width_mult))

他同时包含了apply、lambda、setattr几个我不懂的点,所以我费了很多时间才搞懂了这句话的意思。他其实等价于下面这段代码

def fn(m):
    setattr(m, 'width_mult', width_mult) 
model.apply(fn)

1. lambda

简单来说,lambda是一个匿名函数,它的一个最简单的例子是:

lambda x:x+1(1)
#等价于
def g(x):
    return x+1

所以

lambda m: setattr(m, 'width_mult', width_mult)
#等价于
def fn(m):
    setattr(m, 'width_mult', width_mult) 

2. setattr

setattr(object, name, value) 作用是给对象object的属性name赋值value

3. apply

model.apply方法会逐个遍历model的子模块

 

def fn(m):
    setattr(m, 'width_mult', width_mult) 
model.apply(fn)

所以综合起来,上面的代码的意思就是,逐个遍历model的子模块,给子模块中的变量width_mult赋值。(部分子模块有这个变量,部分没有)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章