Tensorflow 有一個用來存儲Tensorflow.Variable 的類 saver,把變量存在文件裏,稱爲checkpoint文件。
初始化一個saver: sav = Tensorflow.train.Saver()
保存變量使用:sav.save( se , checkPATH + checkFILE , global_step = i)
checkPATH 是保存checkpoint文件的目錄名。checkFILE是checkpoint文件的前綴,最後生成的文件全名是checkFILE-{gloabl_step值}。global_step按照輪數排序。se 是tensoflow.Session() 得到的變量。
恢復變量: sav.restore(se ,checkpoint文件全路徑名(不含後綴) )
得到最新的checkpoint文件全路徑名(不含後綴):checkpoint文件全路徑名 = tensorflow.train.get_checkpoint_state( checkPATH )
必須要指出的是,保存(save) 和 回覆(restore)不能在一個應用裏串行執行。否則的話 ,restore時會出錯。
下面是源代碼:
import tensorflow as tf
checkPATH = 'model/' #save的目錄名
checkFILE = 'saverTest' #save 數據文件的前綴
def saveVariables() :
print( "start save" )
a = tf.Variable( [12,102,10002] ) #第一個需要save的變量
b = tf.Variable(2*10)
sav = tf.train.Saver( )
se = tf .Session()
init =tf.global_variables_initializer()
se.run( init )
for i in range( 3 ) :
r = se.run( a )
a = a*2
print( r )
sav.save( se , checkPATH + checkFILE , global_step = i)
se.close()
def restoreVariables() :
print("restore strat" )
va1 = tf.Variable( [0,0,00] )#第一個需要restore的變量,
#名字可以與save時的不一樣,但是類型要一樣
#b = tf.Variable(0)
sav = tf.train.Saver( )
se = tf .Session()
init =tf.global_variables_initializer()
se.run( init )
ckpt = tf.train.get_checkpoint_state( checkPATH )
print( "ckpt:" , ckpt )
print (" path:" ,ckpt.model_checkpoint_path )
if ckpt and ckpt.model_checkpoint_path:
sav.restore(se , ckpt.model_checkpoint_path)
#ckpt.model_checkpoint_path 是最新的存儲文件名
r = se.run( va1 )
print( r )
#print( se.run(b) )
se.close()
#save 和 restore 要分開不同的應用執行。在同一個應用裏先後串行執行,restore時會失敗
#saveVariables()
restoreVariables()