pytorch實現多個模型的weights平均和修改weights

1. 操作說明

有3個結構相同但是weights不同的model組成一個list,models=[model1,model2,model3],還有一箇中心模型fl_model,這四個模型的結構和超參數都相同。

需要進行這樣一種操作:平均models裏面三個模型的weights,把平均之後的weights"賦值"給fl_model的weights。

2.代碼

在tensorflow裏可以直接用model.get_weights()和model.set_weights()來做,比較直觀和方便。感覺pytorch裏面稍微複雜一些。進行上述操作的代碼如下:

worker_state_dict=[x.state_dict() for x in models]
weight_keys=list(worker_state_dict[0].keys())
fed_state_dict=collections.OrderedDict()
for key in weight_keys:
    key_sum=0
    for i in range(len(models)):
        key_sum=key_sum+worker_state_dict[i][key]
    fed_state_dict[key]=key_sum/len(models)
#### update fed weights to fl model
fl_model.load_state_dict(fed_state_dict)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章