Tensorflow——tf.cond()函數詳解

最近在閱讀大佬代碼的時候遇到了一個比較令人困惑的函數tf.cond( )控制數據流向),這裏就拿出來和大家詳解一下。TensorFlow 提供了幾個操作和類可以用來控制操作的執行並向圖中添加條件依賴關係,比如說tf.count_up_to()(對ref進行遞增直到limit)和tf.case()(創建案例case)等,tf.cond()就是其中的一種。

我們先來看一下tf.cond()在官方文檔的定義(其實就已經比較好理解了,有些像if…else的感覺,控制數據流現在或者延遲流向下一個操作,個人理解,雖然true_fn和false_fn都已經被執行了但tf.cond()可以控制哪個操作可以向下執行):

# 如果pred爲True,那麼則返回true_fn,否則則返回false_fn

cond ( 
    pred ,                 # 標量決定是否返回 true_fn 或 false_fn 結果(要是布爾型)
    true_fn = None ,       # 如果 pred 爲 true, 則被調用
    false_fn = None ,      # 如果 pred 爲 false, 則被調用
    strict = False ,       # 啓用/禁用 “嚴格”模式的布爾值
    name = None ,          # 返回的張量的可選名稱前綴
    fn1 = None , 
    fn2 = None
 )

官方也舉了個栗子🌰,也比較好了解,這裏也放在這裏供大家參考(如果像x<y,則執行tf.add(x,z),否則執行tf.square(y)):

import tensorflow as tf

# 官方文檔:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

# 舉個栗子:
a=tf.constant(2)    
b=tf.constant(3)    
x=tf.constant(4)    
y=tf.constant(5)    
z = tf.multiply(a, b)    
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

sess = tf.Session()
print(sess.run(tf.add(x,z)))       # 輸出結果:10
print(sess.run(tf.square(y)))      # 輸出結果:25
print(sess.run(result))            # 輸出結果:10 原因:x<y is True

一般在代碼中也會直接用布爾型的數組來指定數據的流向,舉個栗子🌰:

import tensorflow as tf

k = tf.placeholder_with_default([False,True],shape=[2],name='shortcut')
cw_1 = tf.cond(k[0],    # False,執行第二個lambda
				lambda:tf.constant(100),
				lambda:tf.cond(k[1],    # True,執行第一個lambda
								lambda:tf.constant(256),
								lambda:tf.constant(512)))

sess = tf.Session()
print(sess.run(cw_1))       # 輸出結果:256
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章