tensorflow學習筆記--提取張量的某一維數據

2019.10.10

提取張量的某一維

 

一、要點總結

最好使用tf.slice函數,可以保持提取出的切片的維度,如果採用類似list的方法,維度會降低,具體可看代碼

 

二、實驗代碼

場景:假設我有一個3x3的張量,我希望提取出其第一列,然後每一列都減去第一列的值

import tensorflow as tf
import numpy as np

a = np.array(range(9))
x = tf.constant(a)
x = tf.reshape(x,[3,3])
z = x[:,0]
y = x - z
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(z))
	print(sess.run(x))
	print(sess.run(y))

如果採用上述代碼,則輸出爲

[0 3 6]
[[0 1 2]
 [3 4 5]
 [6 7 8]]
[[ 0 -2 -4]
 [ 3  1 -1]
 [ 6  4  2]]

即雖然正確提取出了張量的值,但變成了1維的向量,再根據broadcast的規則,無法達成我期望的操作,即每一列都減去第一列

import tensorflow as tf
import numpy as np

a = np.array(range(9))
x = tf.constant(a)
x = tf.reshape(x,[3,3])
z = tf.slice(x,[0,0],[-1,1])
y = x - z
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(z))
	print(sess.run(x))
	print(sess.run(y))

換用tf.slice函數,則可以正確實現

[[0]
 [3]
 [6]]
[[0 1 2]
 [3 4 5]
 [6 7 8]]
[[0 1 2]
 [0 1 2]
 [0 1 2]]

關於tf.slice的一些說明:

1.slice函數必須有的3個輸入是input,begin和size

input是等待被提取的張量

begin是每一維我希望從什麼位置開始提取

size是每一維我希望提取出多少,-1表示全部提取

2.更詳細的內容參考

中文:https://www.w3cschool.cn/tensorflow_python/tensorflow_python-cdj92kbd.html

官方:https://tensorflow.google.cn/api_docs/python/tf/slice

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章