【深度学习 走进tensorflow2.0】TensorFlow 2.0 常用模块@tf.function

1、@tf.function 是干什么用的?
虽然默认的 Eager Execution 为我们带来了灵活及易调试的特性,但在特定的场合(例如追求高性能或部署模型)时,我们依然希望使用图模式,将模型转换为 TensorFlow 图模型。此时,TensorFlow 2.0 为我们提供了 tf.function 模块,结合 AutoGraph 机制,使得我们仅需加入一个简单的 @tf.function 修饰符,就能轻松将模型以图模式运行!是不是很方便呢,下面我们具体看看如何使用它。

2、@tf.function 怎么使用?

在 TensorFlow 2.0 中,推荐使用 @tf.function (而非 1.X 中的 tf.Session )实现 Graph Execution,从而将模型转换为易于部署且高性能的 TensorFlow 图模型。只需要将我们希望以 Graph Execution 模式运行的代码封装在一个函数内,并在函数前加上 @tf.function 即可。

警告提醒:
并不是任何函数都可以被 @tf.function 修饰!@tf.function 使用静态编译将函数内的代码转换成计算图,因此对函数内可使用的语句有一定限制(仅支持 Python 语言的一个子集),且需要函数内的操作本身能够被构建为计算图。建议在函数内只使用 TensorFlow 的原生操作,不要使用过于复杂的 Python 语句,函数参数只包括 TensorFlow 张量或 NumPy 数组,并最好是能够按照计算图的思想去构建函数(换言之,@tf.function 只是给了你一种更方便的写计算图的方法,而不是万能的)


import tensorflow as tf
import time
from zh.model.mnist.cnn import CNN
from zh.model.utils import MNISTLoader

num_batches = 400
batch_size = 50
learning_rate = 0.001
data_loader = MNISTLoader()

@tf.function
def train_one_step(X, y):
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        # 注意这里使用了TensorFlow内置的tf.print()。@tf.function不支持Python内置的print方法
        tf.print("loss", loss)
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

if __name__ == '__main__':
    model = CNN()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    start_time = time.time()
    for batch_index in range(num_batches):
        X, y = data_loader.get_batch(batch_size)
        train_one_step(X, y)
    end_time = time.time()
    print(end_time - start_time)

3、@tf.function 内在机制 怎么样?

当被 @tf.function 修饰的函数第一次被调用的时候,进行以下操作:
在 Eager Execution 模式关闭的环境下,函数内的代码依次运行。也就是说,每个 tf. 方法都只是定义了计算节点,而并没有进行任何实质的计算。这与 TensorFlow 1.X 的 Graph Execution 是一致的;
使用 AutoGraph 将函数中的 Python 控制流语句转换成 TensorFlow 计算图中的对应节点(比如说 while 和 for 语句转换为 tf.while , if 语句转换为 tf.cond 等等;
基于上面的两步,建立函数内代码的计算图表示(为了保证图的计算顺序,图中还会自动加入一些 tf.control_dependencies 节点);
运行一次这个计算图;
基于函数的名字和输入的函数参数的类型生成一个哈希值,并将建立的计算图缓存到一个哈希表中。

在被 @tf.function 修饰的函数之后再次被调用的时候,根据函数名和输入的函数参数的类型计算哈希值,检查哈希表中是否已经有了对应计算图的缓存。如果是,则直接使用已缓存的计算图,否则重新按上述步骤建立计算图。

4、AutoGraph:将 Python 控制流转换为 TensorFlow 计算图

前面提到,@tf.function 使用名为 AutoGraph 的机制将函数中的 Python 控制流语句转换成 TensorFlow 计算图中的对应节点。以下是一个示例,使用 tf.autograph 模块的低层 API tf.autograph.to_code 将函数 square_if_positive 转换成 TensorFlow 计算图:

import tensorflow as tf

@tf.function
def square_if_positive(x):
    if x > 0:
        x = x * x
    else:
        x = 0
    return x

a = tf.constant(1)
b = tf.constant(-1)
print(square_if_positive(a), square_if_positive(b))
print(tf.autograph.to_code(square_if_positive.python_function))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章