TL;DR
tf.Print(input,data)
a=tf.Print(a,["value",a,"shape",tf.shape(a)])
tensor a可以定義在代碼的任意一個位置,只要在session.run時節點a有數據流過(否則你也不會想要debug它),data就會被打印到終端。
Code
import tensorflow as tf
def _test_():
with tf.variable_scope('ts'):
a=tf.constant(1)
a=tf.Print(a,["value",a,"shape",tf.shape(a)])
return a
import tensorflow as tf
from TryDebug import _test_
if __name__ == '__main__':
with tf.Session() as sess:
print(sess.run(_test_()))
Play with it!
Some Detail
- 信息被打印到標準錯誤流,而非標準輸出流,在pycharm裏,前者紅色,後者白色
- tf.Print是一個tensor操作,這意味着他也會在圖上添加一個節點,這個節點的功能和tf.identity類似,僅傳遞數據,打印錯誤輸出是它的副作用。
- data的輸入必須是string或者tensor
- 實際上,tf.Print的input和data並不一定匹配,input的作用僅僅是將節點添加到圖上,而data可以是任何有數據流過的節點。
Note:這篇文章適用於tensorflow2.0的靜態計算圖+會話機制,對於2.0新增加的keras模式和動態計算圖,我尚未學習,也不確定是否有更好的辦法!