Tensorflow調試指南:tf.Print

TL;DR

tf.Print(input,data)

a=tf.Print(a,["value",a,"shape",tf.shape(a)])

tensor a可以定義在代碼的任意一個位置,只要在session.run時節點a有數據流過(否則你也不會想要debug它),data就會被打印到終端。

Code

import tensorflow as tf

def _test_():
    with tf.variable_scope('ts'):
        a=tf.constant(1)
        a=tf.Print(a,["value",a,"shape",tf.shape(a)])
    return a
import tensorflow as tf
from TryDebug import _test_
if __name__ == '__main__':
    with tf.Session() as sess:
        print(sess.run(_test_()))

Play with it!

Some Detail

  1. 信息被打印到標準錯誤流,而非標準輸出流,在pycharm裏,前者紅色,後者白色
    在這裏插入圖片描述
  2. tf.Print是一個tensor操作,這意味着他也會在圖上添加一個節點,這個節點的功能和tf.identity類似,僅傳遞數據,打印錯誤輸出是它的副作用。
  3. data的輸入必須是string或者tensor
  4. 實際上,tf.Print的input和data並不一定匹配,input的作用僅僅是將節點添加到圖上,而data可以是任何有數據流過的節點。

Note:這篇文章適用於tensorflow2.0的靜態計算圖+會話機制,對於2.0新增加的keras模式和動態計算圖,我尚未學習,也不確定是否有更好的辦法!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章