tf.split()、tf.tile()函数的用法和例子

tf.split()

顾名思义就是将tensor分割成为列表的形式。通常tf.split之后往往会跟tf.concat结合使用。

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)

value:准备切分的张量
num_or_size_splits:准备切成几份
axis : 准备在第几个维度上进行切割
其中分割方式分为两种

  1. 如果num_or_size_splits 传入的 是一个整数,那直接在axis=D这个维度上把张量平均切分成几个小张量
  2. 如果num_or_size_splits 传入的是一个向量(这里向量各个元素的和要跟原本这个维度的数值相等)就根据这个向量有几个元素分为几项)

举个例子:

# 张量为(530)
# 这个时候5是axis=030是axis=1,如果要在axis=1这个维度上把这个张量拆分成三个子张量
#传入向量时
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0)  # [5, 4]
tf.shape(split1)  # [5, 15]
tf.shape(split2)  # [5, 11]
# 传入整数时
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0)  # [5, 10]

tf.tile()

此函数的作用就是将tensor在某一维度上进行扩展,就是将原来的tensor复制多次,然后拼接在制定的axis上。

tf.tile(
    input, multiples, name=None
)

举个例子:
在这里插入图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章