Tensorflow 的 saver

Tensorflow 有一個用來存儲Tensorflow.Variable 的類 saver,把變量存在文件裏,稱爲checkpoint文件。

初始化一個saver: sav = Tensorflow.train.Saver()

保存變量使用:sav.save(  se ,  checkPATH + checkFILE , global_step = i) 

checkPATH 是保存checkpoint文件的目錄名。checkFILE是checkpoint文件的前綴,最後生成的文件全名是checkFILE-{gloabl_step值}。global_step按照輪數排序。se 是tensoflow.Session() 得到的變量。

恢復變量: sav.restore(se ,checkpoint文件全路徑名(不含後綴) )

得到最新的checkpoint文件全路徑名(不含後綴):checkpoint文件全路徑名 = tensorflow.train.get_checkpoint_state( checkPATH )

必須要指出的是,保存(save) 和 回覆(restore)不能在一個應用裏串行執行。否則的話 ,restore時會出錯。

下面是源代碼:

import tensorflow as tf

checkPATH =  'model/' #save的目錄名
checkFILE = 'saverTest' #save 數據文件的前綴


def saveVariables() :
    print( "start save" ) 
    a = tf.Variable( [12,102,10002] ) #第一個需要save的變量
    b = tf.Variable(2*10)
    sav = tf.train.Saver( )
    se = tf .Session()
    init =tf.global_variables_initializer()
    se.run( init )
    for i in range( 3 ) :
        r = se.run( a )
        a = a*2
        print( r )
        sav.save(  se ,  checkPATH + checkFILE , global_step = i)
    se.close()
    


def restoreVariables() :
    print("restore strat" ) 
    va1 = tf.Variable( [0,0,00] )#第一個需要restore的變量,
                                             #名字可以與save時的不一樣,但是類型要一樣
    #b = tf.Variable(0) 
    sav = tf.train.Saver(  )
    se = tf .Session()
    init =tf.global_variables_initializer()
    se.run( init )
    ckpt = tf.train.get_checkpoint_state( checkPATH )
    print( "ckpt:" ,  ckpt )
    print (" path:" ,ckpt.model_checkpoint_path ) 
    if ckpt and ckpt.model_checkpoint_path:
        sav.restore(se , ckpt.model_checkpoint_path)
               #ckpt.model_checkpoint_path 是最新的存儲文件名
    r = se.run( va1 )
    print( r )
    #print( se.run(b) )
    se.close() 


#save 和 restore 要分開不同的應用執行。在同一個應用裏先後串行執行,restore時會失敗   

#saveVariables() 
restoreVariables()


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章