concat and stack
tf.concat相當於numpy中的np.concatenate函數,用於將兩個張量在某一個維度(axis)合併起來
a=tf.reshape(np.arange(4),(2,2))
[[0 1]
[2 3]]
b=tf.reshape(np.arange(4,8),(2,2))
[[4 5]
[6 7]]
c=tf.concat([a,b],axis=0)
[[0 1]
[2 3]
[4 5]
[6 7]]
d=tf.concat([a,b],axis=1)
[[0 1 4 5]
[2 3 6 7]]
tf.concat拼接的是兩個shape完全相同的張量,並且產生的張量的維度不會發生變化,而tf.stack拼接後的張量的維度+1
tf.stack 的axis 值取值範圍爲 -(R+1)~(R+1)
a=tf.reshape(np.arange(4),(2,2))
b=tf.reshape(np.arange(4,8),(2,2))
#axis默認爲0
c=tf.stack([a,b])
[[[0 1]
[2 3]]
[[4 5]
[6 7]]]
#以下等價
d=tf.stack([a,b],axis=1)
d=tf.stack([a,b],axis=-2)
[[[0 1]
[4 5]]
[[2 3]
[6 7]]]
strack and transpos
x=np.arange(12).reshape((2,2,3))
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
#等價
print(np.stack(x,axis=0))
print(np.transpose(x,(0,1,2)))
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
#等價
print(np.stack(x,axis=1))
print(np.transpose(x,(1,0,2)))
[[[ 0 1 2]
[ 6 7 8]]
[[ 3 4 5]
[ 9 10 11]]]
#等價
print(np.stack(x,axis=2))
print(np.transpose(x,(1,2,0)))
[[[ 0 6]
[ 1 7]
[ 2 8]]
[[ 3 9]
[ 4 10]
[ 5 11]]]