tf.expand_dims()函數解析(最清晰的解釋)

歡迎關注WX公衆號:【程序員管小亮】

tf.expand_dims()函數用於給函數增加維度。

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

參數:

  • input是輸入張量。

  • axis是指定擴大輸入張量形狀的維度索引值。

  • dim等同於軸,一般不推薦使用。

函數的功能是在給定一個input時,在axis軸處給input增加一個維度。

axis:

給定張量輸入input,此操作爲選擇維度索引值,在輸入形狀的維度索引值的軸處插入1的維度。 維度索引值的軸從零開始; 如果您指定軸是負數,則從最後向後進行計數,也就是倒數。

import tensorflow as tf

# 't' is a tensor of shape [2]
t = tf.constant([1,2])
print(t.shape)
t1 = tf.expand_dims(t, 0)
print(t1.shape)
t2 = tf.expand_dims(t, 1)
print(t2.shape)
t3 = tf.expand_dims(t, 1)
print(t3.shape)

> (2,)
> (1, 2)
> (2, 1)
> (2, 1)
import tensorflow as tf
import numpy as np

# 't2' is a tensor of shape [2, 3, 5]
t2 = np.zeros((2,3,5))
print(t2.shape)
t3 = tf.expand_dims(t2, 0)
t4 = tf.expand_dims(t2, 2)
t5 = tf.expand_dims(t2, 3)
print(t3.shape)
print(t4.shape)
print(t5.shape)

> (2, 3, 5)
> (1, 2, 3, 5)
> (2, 3, 1, 5)
> (2, 3, 5, 1)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章