【深度学习 走开tensorflow2.0】TensorFlow 2.0 常用模块tf.TensorArray

无意中发现了一个巨牛的人工智能教程,忍不住分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。点这里可以跳转到教程。人工智能教程

在部分网络结构,尤其是涉及到时间序列的结构中,我们可能需要将一系列张量以数组的方式依次存放起来,以供进一步处理。当然,在 Eager Execution 下,你可以直接使用一个 Python 列表(List)存放数组。不过,如果你需要基于计算图的特性(例如使用 @tf.function 加速模型运行或者使用 SavedModel 导出模型),就无法使用这种方式了。因此,TensorFlow 提供了 tf.TensorArray ,一种支持计算图特性的 TensorFlow 动态数组。下面将介绍tensorflow2.0在计算图模式中使用动态数组保存和读取张量的方法。

其声明的方式为:
arr = tf.TensorArray(dtype, size, dynamic_size=False) :声明一个大小为size,类型为dtype的 TensorArrayarr。如果将dynamic_size参数设置为True,则该数组会自动增长空间。
其读取和写入的方法为:
write(index, value) :将 value 写入数组的第 index 个位置;
read(index) :读取数组的第 index 个值;

请注意,由于需要支持计算图,tf.TensorArray的write()方法是不可以忽略左值的!也就是说,在 Graph Execution 模式下,必须按照以下的形式写入数组:

正确写法:
arr = arr.write(index, value)

错误写法:
arr.write(index, value)     # 生成的计算图操作没有左值接收,从而丢失

一个例子:

import tensorflow as tf

@tf.function
def array_write_and_read():
    arr = tf.TensorArray(dtype=tf.float32, size=3)
    arr = arr.write(0, tf.constant(0.0))
    arr = arr.write(1, tf.constant(1.0))
    arr = arr.write(2, tf.constant(2.0))
    arr_0 = arr.read(0)
    arr_1 = arr.read(1)
    arr_2 = arr.read(2)
    return arr_0, arr_1, arr_2

a, b, c = array_write_and_read()
print(a, b, c)

运行结果:

tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(2.0, shape=(), dtype=float32)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章