2019.10.10
提取張量的某一維
一、要點總結
最好使用tf.slice函數,可以保持提取出的切片的維度,如果採用類似list的方法,維度會降低,具體可看代碼
二、實驗代碼
場景:假設我有一個3x3的張量,我希望提取出其第一列,然後每一列都減去第一列的值
import tensorflow as tf
import numpy as np
a = np.array(range(9))
x = tf.constant(a)
x = tf.reshape(x,[3,3])
z = x[:,0]
y = x - z
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(z))
print(sess.run(x))
print(sess.run(y))
如果採用上述代碼,則輸出爲
[0 3 6]
[[0 1 2]
[3 4 5]
[6 7 8]]
[[ 0 -2 -4]
[ 3 1 -1]
[ 6 4 2]]
即雖然正確提取出了張量的值,但變成了1維的向量,再根據broadcast的規則,無法達成我期望的操作,即每一列都減去第一列
import tensorflow as tf
import numpy as np
a = np.array(range(9))
x = tf.constant(a)
x = tf.reshape(x,[3,3])
z = tf.slice(x,[0,0],[-1,1])
y = x - z
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(z))
print(sess.run(x))
print(sess.run(y))
換用tf.slice函數,則可以正確實現
[[0]
[3]
[6]]
[[0 1 2]
[3 4 5]
[6 7 8]]
[[0 1 2]
[0 1 2]
[0 1 2]]
關於tf.slice的一些說明:
1.slice函數必須有的3個輸入是input,begin和size
input是等待被提取的張量
begin是每一維我希望從什麼位置開始提取
size是每一維我希望提取出多少,-1表示全部提取
2.更詳細的內容參考
中文:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-cdj92kbd.html