【深度學習 走進tensorflow2.0】TensorFlow 2.0 常用模塊@tf.function

1、@tf.function 是幹什麼用的?
雖然默認的 Eager Execution 爲我們帶來了靈活及易調試的特性,但在特定的場合(例如追求高性能或部署模型)時,我們依然希望使用圖模式,將模型轉換爲 TensorFlow 圖模型。此時,TensorFlow 2.0 爲我們提供了 tf.function 模塊,結合 AutoGraph 機制,使得我們僅需加入一個簡單的 @tf.function 修飾符,就能輕鬆將模型以圖模式運行!是不是很方便呢,下面我們具體看看如何使用它。

2、@tf.function 怎麼使用?

在 TensorFlow 2.0 中,推薦使用 @tf.function (而非 1.X 中的 tf.Session )實現 Graph Execution,從而將模型轉換爲易於部署且高性能的 TensorFlow 圖模型。只需要將我們希望以 Graph Execution 模式運行的代碼封裝在一個函數內,並在函數前加上 @tf.function 即可。

警告提醒:
並不是任何函數都可以被 @tf.function 修飾!@tf.function 使用靜態編譯將函數內的代碼轉換成計算圖,因此對函數內可使用的語句有一定限制(僅支持 Python 語言的一個子集),且需要函數內的操作本身能夠被構建爲計算圖。建議在函數內只使用 TensorFlow 的原生操作,不要使用過於複雜的 Python 語句,函數參數只包括 TensorFlow 張量或 NumPy 數組,並最好是能夠按照計算圖的思想去構建函數(換言之,@tf.function 只是給了你一種更方便的寫計算圖的方法,而不是萬能的)


import tensorflow as tf
import time
from zh.model.mnist.cnn import CNN
from zh.model.utils import MNISTLoader

num_batches = 400
batch_size = 50
learning_rate = 0.001
data_loader = MNISTLoader()

@tf.function
def train_one_step(X, y):
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        # 注意這裏使用了TensorFlow內置的tf.print()。@tf.function不支持Python內置的print方法
        tf.print("loss", loss)
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

if __name__ == '__main__':
    model = CNN()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    start_time = time.time()
    for batch_index in range(num_batches):
        X, y = data_loader.get_batch(batch_size)
        train_one_step(X, y)
    end_time = time.time()
    print(end_time - start_time)

3、@tf.function 內在機制 怎麼樣?

當被 @tf.function 修飾的函數第一次被調用的時候,進行以下操作:
在 Eager Execution 模式關閉的環境下,函數內的代碼依次運行。也就是說,每個 tf. 方法都只是定義了計算節點,而並沒有進行任何實質的計算。這與 TensorFlow 1.X 的 Graph Execution 是一致的;
使用 AutoGraph 將函數中的 Python 控制流語句轉換成 TensorFlow 計算圖中的對應節點(比如說 while 和 for 語句轉換爲 tf.while , if 語句轉換爲 tf.cond 等等;
基於上面的兩步,建立函數內代碼的計算圖表示(爲了保證圖的計算順序,圖中還會自動加入一些 tf.control_dependencies 節點);
運行一次這個計算圖;
基於函數的名字和輸入的函數參數的類型生成一個哈希值,並將建立的計算圖緩存到一個哈希表中。

在被 @tf.function 修飾的函數之後再次被調用的時候,根據函數名和輸入的函數參數的類型計算哈希值,檢查哈希表中是否已經有了對應計算圖的緩存。如果是,則直接使用已緩存的計算圖,否則重新按上述步驟建立計算圖。

4、AutoGraph:將 Python 控制流轉換爲 TensorFlow 計算圖

前面提到,@tf.function 使用名爲 AutoGraph 的機制將函數中的 Python 控制流語句轉換成 TensorFlow 計算圖中的對應節點。以下是一個示例,使用 tf.autograph 模塊的低層 API tf.autograph.to_code 將函數 square_if_positive 轉換成 TensorFlow 計算圖:

import tensorflow as tf

@tf.function
def square_if_positive(x):
    if x > 0:
        x = x * x
    else:
        x = 0
    return x

a = tf.constant(1)
b = tf.constant(-1)
print(square_if_positive(a), square_if_positive(b))
print(tf.autograph.to_code(square_if_positive.python_function))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章