tensorflow2.x的autograph

使用autograph需要注意的地方

  1. 使用tf自帶的函數,而不是原始的Python函數。

因爲python中的函數僅僅會在跟蹤執行函數以創建靜態圖的階段使用,普通python函數無法嵌入到靜態圖中,所以在計算圖構建好之後再次調用的時候,這些python函數並沒有被計算,使用python函數會導致被裝飾器裝飾前(eager執行)和被裝飾器修飾後(靜態圖執行)的輸出不一致。

例如:
被@tf.function修飾過的函數利用np生成隨機數每一次都是一樣的,而使用tf就是不一樣的。

import tensorflow as tf
import numpy as np


@tf.function
def np_random():
    a = np.random.randn(3, 3)
    tf.print(a)


@tf.function
def tf_random():
    b = tf.random.normal((3, 3))
    tf.print(b)

# 被修飾過的np隨機數每一次都是一樣的
np_random()
np_random()
# 被修飾過的tf隨機數每一次不一樣
tf_random()
tf_random()

輸出:

array([[-0.15504732,  1.55624863,  0.05300533],
       [ 0.31764741, -0.74893782, -1.09177159],
       [ 0.99932818,  0.45183048,  0.52071062]])
array([[-0.15504732,  1.55624863,  0.05300533],
       [ 0.31764741, -0.74893782, -1.09177159],
       [ 0.99932818,  0.45183048,  0.52071062]])
[[-0.477426946 0.0352273062 1.25906813]
 [-1.60776 -0.919097424 0.199720368]
 [-0.819047332 -0.321983 -1.48350918]]
[[0.992117703 -0.778428 0.158320323]
 [1.04978251 -0.56633848 -0.407015771]
 [1.45319021 0.473201066 0.672261238]]
  1. 避免在@tf.function修飾的函數內部定義tf.Variable

如果函數內部定義了tf.Variable,那麼在eager執行的時候,這種創建變量的行爲在每次函數調用的時候都會發生,但是在靜態圖執行的時候,這種創建變量的行爲只會發生在第一步跟蹤python代碼邏輯創建計算圖時。這會導致修飾前後輸出結果不一致

如:
在函數內部定義tf.Variable然後報錯

# 在修飾外部創建
x = tf.Variable(1.0, dtype=tf.float32)
@tf.function
def out():
    x.assign_add(1.0)
    tf.print(x)

out()
out()

@tf.function
def inside():
    a = tf.Variable(1.0, dtype=tf.float32)
    a.assign_add(1.0)
    tf.print(a)

# 報錯
# inside()
# inside()
  1. 被@tf.function修飾的函數不可修改該函數外部的python列表或者字典等結構類型變量

靜態計算圖是被編譯成c++代碼在tensorflow內核中執行的,python中的列表和字典等數據結構變量時無法嵌入到計算圖中的,他們僅僅能夠在創建計算圖的時候被讀取,在執行計算圖的時候無法修改Python中的列表和字典這樣的數據結構變量的"""

lyst = []
@tf.function
def append_list(x):
    lyst.append(x)
    tf.print(lyst)

append_list(tf.constant(4.3))
append_list(tf.constant(2.3))
print(lyst)

輸出:

[<tf.Tensor 'x:0' shape=() dtype=float32>]
[4.3]
[2.3]

autograph的工作機制

工作機制:

  1. 創建計算圖
  2. 執行計算圖

第一次輸入兩個參數都爲str的數據

@tf.function(autograph=True)
def add(a, b):
    for i in tf.range(4):
        tf.print(i)
    c = a + b
    print('finish')
    return c

# 第一次輸入:
add(tf.constant('hello'), tf.constant('world'))

輸出:

finish
0
1
2
3

當再次使用相同的輸入參數類型調用這個被@tf.function裝飾的函數的時,只會發生一件事情,那就是執行上面步驟的第二步,執行計算圖。所以沒有看見打印’finish‘的結果

add(tf.constant('td'), tf.constant('new_world'))

輸出:

0
1
2
3

當使用不同的類型參數作爲輸入的時候,會重新創建一個計算圖,由於參數的類型發生了變化,已經創建的計算圖不能再次使用

add(tf.constant(3.4), tf.constant(4.4))

輸出:

finish
0
1
2
3

當輸入的不是tensor類型的時候,每一次都會重新創建計算圖,所以建議調用@tf.function的時候傳入tensor類型的數據。例如

add('td', 'new_world')

輸出:

finish
0
1
2
3
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章