1. 操作說明
有3個結構相同但是weights不同的model組成一個list,models=[model1,model2,model3],還有一箇中心模型fl_model,這四個模型的結構和超參數都相同。
需要進行這樣一種操作:平均models裏面三個模型的weights,把平均之後的weights"賦值"給fl_model的weights。
2.代碼
在tensorflow裏可以直接用model.get_weights()和model.set_weights()來做,比較直觀和方便。感覺pytorch裏面稍微複雜一些。進行上述操作的代碼如下:
worker_state_dict=[x.state_dict() for x in models]
weight_keys=list(worker_state_dict[0].keys())
fed_state_dict=collections.OrderedDict()
for key in weight_keys:
key_sum=0
for i in range(len(models)):
key_sum=key_sum+worker_state_dict[i][key]
fed_state_dict[key]=key_sum/len(models)
#### update fed weights to fl model
fl_model.load_state_dict(fed_state_dict)