【深度學習 走開tensorflow2.0】TensorFlow 2.0 常用模塊tf.TensorArray

無意中發現了一個巨牛的人工智能教程,忍不住分享一下給大家。教程不僅是零基礎,通俗易懂,而且非常風趣幽默,像看小說一樣!覺得太牛了,所以分享給大家。點這裏可以跳轉到教程。人工智能教程

在部分網絡結構,尤其是涉及到時間序列的結構中,我們可能需要將一系列張量以數組的方式依次存放起來,以供進一步處理。當然,在 Eager Execution 下,你可以直接使用一個 Python 列表(List)存放數組。不過,如果你需要基於計算圖的特性(例如使用 @tf.function 加速模型運行或者使用 SavedModel 導出模型),就無法使用這種方式了。因此,TensorFlow 提供了 tf.TensorArray ,一種支持計算圖特性的 TensorFlow 動態數組。下面將介紹tensorflow2.0在計算圖模式中使用動態數組保存和讀取張量的方法。

其聲明的方式爲:
arr = tf.TensorArray(dtype, size, dynamic_size=False) :聲明一個大小爲size,類型爲dtype的 TensorArrayarr。如果將dynamic_size參數設置爲True,則該數組會自動增長空間。
其讀取和寫入的方法爲:
write(index, value) :將 value 寫入數組的第 index 個位置;
read(index) :讀取數組的第 index 個值;

請注意,由於需要支持計算圖,tf.TensorArray的write()方法是不可以忽略左值的!也就是說,在 Graph Execution 模式下,必須按照以下的形式寫入數組:

正確寫法:
arr = arr.write(index, value)

錯誤寫法:
arr.write(index, value)     # 生成的計算圖操作沒有左值接收,從而丟失

一個例子:

import tensorflow as tf

@tf.function
def array_write_and_read():
    arr = tf.TensorArray(dtype=tf.float32, size=3)
    arr = arr.write(0, tf.constant(0.0))
    arr = arr.write(1, tf.constant(1.0))
    arr = arr.write(2, tf.constant(2.0))
    arr_0 = arr.read(0)
    arr_1 = arr.read(1)
    arr_2 = arr.read(2)
    return arr_0, arr_1, arr_2

a, b, c = array_write_and_read()
print(a, b, c)

運行結果:

tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(2.0, shape=(), dtype=float32)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章