該項目中把tf的數據存儲和讀取抽取出兩個函數,方便開發,思想和代碼值得借遷
一.存儲
def save_variables(save_path, variables=None, sess=None):
import joblib
sess = sess or get_session()
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
ps = sess.run(variables)
save_dict = {v.name: value for v, value in zip(variables, ps)}
dirname = os.path.dirname(save_path)
if any(dirname):
os.makedirs(dirname, exist_ok=True)
joblib.dump(save_dict, save_path)
- 第一次見 or 這樣寫,意思就是前一個不是None或者0,就取前一個,否則取後一個。
- tf裏,一個session就保存了各種訓練的數據和計算圖,所依直接把sess傳過來,從tf自帶的tf.GraphKeys.GLOBAL_VARIABLES取出其中的全局變量名。然後run()一下就能得到參數值,再放入一個字典容器
- 根據路徑存入joblib裏面
- 其中joblib是sklearn中的一個專門用於保存訓練的模型的
不知道的點這裏
二.加載
def load_variables(load_path, variables=None, sess=None):
import joblib
sess = sess or get_session()
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
loaded_params = joblib.load(os.path.expanduser(load_path))
restores = []
if isinstance(loaded_params, list):
assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)'
for d, v in zip(loaded_params, variables):
restores.append(v.assign(d))
else:
for v in variables:
restores.append(v.assign(loaded_params[v.name]))
sess.run(restores)
跟上面相反,看代碼就能明白
- assign()是tf裏的賦值函數,注意tf裏的操作寫完都要run()才能生效,不然它僅僅是圖上的一個結點
- isinstance()是python裏比較兩個對象是否相同,它具有繼承關係,也就是說如果他的父類相同,也算同一類