tensorflow 加載參數,強化學習模型可擴展。

強化學習中想要在4個智能體訓練出來的模型擴展到更多智能體的模型。

若要可擴展,首先需要所有智能體共享一個模型參數,方法是設置reuse=True即可。

然而,現實操作可能還會存在問題。因此通常三個步驟來查看。

1、通常先輸出網絡的參數。下面三種方法稍有不同,具體輸出可以看出。

# 得到該網絡中,所有可以加載的參數的三種方法
variables = tf.contrib.framework.get_variables_to_restore()
for i in range(len(variables)):
    print(variables[i])
print('------------variables22222222222---------------')
variables2 = tf.all_variables()
for i in range(len(variables2)):
    print(variables2[i])
print('---------variables33333333------------------')
variables2 = tf.trainable_variables() #得到可訓練參數
for i in range(len(variables2)):
   print(variables2[i])

2、過濾一些無關緊要的參數

 # 刪除其他參數,adv_agent層中的參數
        variables_to_restore = [v for v in variables if v.name.split('/')[0] == 'adv_agent']

3、 加載model的(過濾後)這部分參數

# 構建這部分參數的saver
        saver = tf.train.Saver(variables_to_restore)
        saver.restore(U.get_session(),arglist.load_dir) #arglist.load_dir就是模型存儲的位置

參考鏈接:

1、https://blog.csdn.net/marsjhao/article/details/72829635#commentBox

2、 https://blog.csdn.net/huachao1001/article/details/78501928

3、https://blog.csdn.net/u011961856/article/details/77064631

4、https://blog.csdn.net/b876144622/article/details/79962727

5、https://blog.csdn.net/jeryjeryjery/article/details/79880475

6、https://blog.csdn.net/CV_YOU/article/details/80698942 (多次載入部分參數)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章