強化學習 ---baseline項目之 TensorFlow的訓練參數的存儲和加載

       該項目中把tf的數據存儲和讀取抽取出兩個函數,方便開發,思想和代碼值得借遷



一.存儲

def save_variables(save_path, variables=None, sess=None):
    import joblib
    sess = sess or get_session()
    variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    ps = sess.run(variables)
    save_dict = {v.name: value for v, value in zip(variables, ps)}
    dirname = os.path.dirname(save_path)
    if any(dirname):
        os.makedirs(dirname, exist_ok=True)
    joblib.dump(save_dict, save_path)
  • 第一次見 or 這樣寫,意思就是前一個不是None或者0,就取前一個,否則取後一個。
  • tf裏,一個session就保存了各種訓練的數據和計算圖,所依直接把sess傳過來,從tf自帶的tf.GraphKeys.GLOBAL_VARIABLES取出其中的全局變量名。然後run()一下就能得到參數值,再放入一個字典容器
  • 根據路徑存入joblib裏面
  • 其中joblib是sklearn中的一個專門用於保存訓練的模型的
    不知道的點這裏

二.加載

def load_variables(load_path, variables=None, sess=None):
    import joblib
    sess = sess or get_session()
    variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    loaded_params = joblib.load(os.path.expanduser(load_path))
    restores = []
    if isinstance(loaded_params, list):
        assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)'
        for d, v in zip(loaded_params, variables):
            restores.append(v.assign(d))
    else:
        for v in variables:
            restores.append(v.assign(loaded_params[v.name]))

    sess.run(restores)

跟上面相反,看代碼就能明白

  • assign()是tf裏的賦值函數,注意tf裏的操作寫完都要run()才能生效,不然它僅僅是圖上的一個結點
  • isinstance()是python裏比較兩個對象是否相同,它具有繼承關係,也就是說如果他的父類相同,也算同一類
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章