使用autograph需要注意的地方
- 使用tf自帶的函數,而不是原始的Python函數。
因爲python中的函數僅僅會在跟蹤執行函數以創建靜態圖的階段使用,普通python函數無法嵌入到靜態圖中,所以在計算圖構建好之後再次調用的時候,這些python函數並沒有被計算,使用python函數會導致被裝飾器裝飾前(eager執行)和被裝飾器修飾後(靜態圖執行)的輸出不一致。
例如:
被@tf.function修飾過的函數利用np生成隨機數每一次都是一樣的,而使用tf就是不一樣的。
import tensorflow as tf
import numpy as np
@tf.function
def np_random():
a = np.random.randn(3, 3)
tf.print(a)
@tf.function
def tf_random():
b = tf.random.normal((3, 3))
tf.print(b)
# 被修飾過的np隨機數每一次都是一樣的
np_random()
np_random()
# 被修飾過的tf隨機數每一次不一樣
tf_random()
tf_random()
輸出:
array([[-0.15504732, 1.55624863, 0.05300533],
[ 0.31764741, -0.74893782, -1.09177159],
[ 0.99932818, 0.45183048, 0.52071062]])
array([[-0.15504732, 1.55624863, 0.05300533],
[ 0.31764741, -0.74893782, -1.09177159],
[ 0.99932818, 0.45183048, 0.52071062]])
[[-0.477426946 0.0352273062 1.25906813]
[-1.60776 -0.919097424 0.199720368]
[-0.819047332 -0.321983 -1.48350918]]
[[0.992117703 -0.778428 0.158320323]
[1.04978251 -0.56633848 -0.407015771]
[1.45319021 0.473201066 0.672261238]]
- 避免在@tf.function修飾的函數內部定義tf.Variable
如果函數內部定義了tf.Variable,那麼在eager執行的時候,這種創建變量的行爲在每次函數調用的時候都會發生,但是在靜態圖執行的時候,這種創建變量的行爲只會發生在第一步跟蹤python代碼邏輯創建計算圖時。這會導致修飾前後輸出結果不一致
如:
在函數內部定義tf.Variable然後報錯
# 在修飾外部創建
x = tf.Variable(1.0, dtype=tf.float32)
@tf.function
def out():
x.assign_add(1.0)
tf.print(x)
out()
out()
@tf.function
def inside():
a = tf.Variable(1.0, dtype=tf.float32)
a.assign_add(1.0)
tf.print(a)
# 報錯
# inside()
# inside()
- 被@tf.function修飾的函數不可修改該函數外部的python列表或者字典等結構類型變量
靜態計算圖是被編譯成c++代碼在tensorflow內核中執行的,python中的列表和字典等數據結構變量時無法嵌入到計算圖中的,他們僅僅能夠在創建計算圖的時候被讀取,在執行計算圖的時候無法修改Python中的列表和字典這樣的數據結構變量的"""
lyst = []
@tf.function
def append_list(x):
lyst.append(x)
tf.print(lyst)
append_list(tf.constant(4.3))
append_list(tf.constant(2.3))
print(lyst)
輸出:
[<tf.Tensor 'x:0' shape=() dtype=float32>]
[4.3]
[2.3]
autograph的工作機制
工作機制:
- 創建計算圖
- 執行計算圖
第一次輸入兩個參數都爲str的數據
@tf.function(autograph=True)
def add(a, b):
for i in tf.range(4):
tf.print(i)
c = a + b
print('finish')
return c
# 第一次輸入:
add(tf.constant('hello'), tf.constant('world'))
輸出:
finish
0
1
2
3
當再次使用相同的輸入參數類型調用這個被@tf.function裝飾的函數的時,只會發生一件事情,那就是執行上面步驟的第二步,執行計算圖。所以沒有看見打印’finish‘的結果
add(tf.constant('td'), tf.constant('new_world'))
輸出:
0
1
2
3
當使用不同的類型參數作爲輸入的時候,會重新創建一個計算圖,由於參數的類型發生了變化,已經創建的計算圖不能再次使用
add(tf.constant(3.4), tf.constant(4.4))
輸出:
finish
0
1
2
3
當輸入的不是tensor類型的時候,每一次都會重新創建計算圖,所以建議調用@tf.function的時候傳入tensor類型的數據。例如
add('td', 'new_world')
輸出:
finish
0
1
2
3