最近在阅读大佬代码的时候遇到了一个比较令人困惑的函数tf.cond( )
(控制数据流向),这里就拿出来和大家详解一下。TensorFlow 提供了几个操作和类可以用来控制操作的执行并向图中添加条件依赖关系,比如说tf.count_up_to()
(对ref进行递增直到limit)和tf.case()
(创建案例case)等,tf.cond()就是其中的一种。
我们先来看一下tf.cond()在官方文档的定义(其实就已经比较好理解了,有些像if…else的感觉,控制数据流现在或者延迟流向下一个操作,个人理解,虽然true_fn和false_fn都已经被执行了但tf.cond()可以控制哪个操作可以向下执行):
# 如果pred为True,那么则返回true_fn,否则则返回false_fn
cond (
pred , # 标量决定是否返回 true_fn 或 false_fn 结果(要是布尔型)
true_fn = None , # 如果 pred 为 true, 则被调用
false_fn = None , # 如果 pred 为 false, 则被调用
strict = False , # 启用/禁用 “严格”模式的布尔值
name = None , # 返回的张量的可选名称前缀
fn1 = None ,
fn2 = None
)
官方也举了个栗子🌰,也比较好了解,这里也放在这里供大家参考(如果像x<y,则执行tf.add(x,z),否则执行tf.square(y)):
import tensorflow as tf
# 官方文档:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
# 举个栗子:
a=tf.constant(2)
b=tf.constant(3)
x=tf.constant(4)
y=tf.constant(5)
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
sess = tf.Session()
print(sess.run(tf.add(x,z))) # 输出结果:10
print(sess.run(tf.square(y))) # 输出结果:25
print(sess.run(result)) # 输出结果:10 原因:x<y is True
一般在代码中也会直接用布尔型的数组来指定数据的流向,举个栗子🌰:
import tensorflow as tf
k = tf.placeholder_with_default([False,True],shape=[2],name='shortcut')
cw_1 = tf.cond(k[0], # False,执行第二个lambda
lambda:tf.constant(100),
lambda:tf.cond(k[1], # True,执行第一个lambda
lambda:tf.constant(256),
lambda:tf.constant(512)))
sess = tf.Session()
print(sess.run(cw_1)) # 输出结果:256