1 源代碼註釋的解釋
- This operation extracts a slice of size
size
from a tensorinput
starting at the location specified bybegin
. The slicesize
is represented as tensor shape, wheresize[i]
is the number of elements of the 'i’th dimension ofinput
that you want to slice. The starting location (begin
) for the slice is represented as an offset in each dimension ofinput
. In other words,begin[i]
is the offset into the 'i’th dimension ofinput
that you want to slice from.
2 tf.slice函數的參數
slice(input, begin, size, name)
input:輸入的tensor變量。
begin:每個維度的起始位置。
size:從每個維度切取的大小。
3 實例解釋
測試代碼如下:
import tensorflow as tf
data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [2, 1, 0], [1, 1, 3])
with tf.Session() as sess:
data, slice_ = sess.run([data, slice_])
print('原始大小:\n', data.shape)
print('原始數據:\n',data)
print('切取後大小:\n', slice_.shape)
print('切取後數據:\n',slice_)
結果輸出如下:
原始大小:
(4, 2, 3)
原始數據:
[[[ 1 2 3]
[11 12 13]]
[[21 22 23]
[31 32 33]]
[[41 42 43]
[51 52 53]]
[[61 62 63]
[71 72 73]]]
切取後大小:
(1, 1, 3)
切取後數據:
[[[51 52 53]]]
可以看到,原始的列表 data 是三維的,即[4, 2, 3],最後切取到的大小也是三維的[1 ,1, 3],切取到的是 [[[51 52 53]]]。
每個維度切取的起始位置分別是[2, 1, 0],並分別在每個維度切取[1,1,3]的大小。(注意python列表的下標是從0開始的)
我們先把最外層的中括號去掉,並以逗號分隔,得到原 data 列表第一維。
第一維是三個元素,分別是:
- [[1, 2, 3], [11, 12, 13]]
- [[21, 22, 23], [31, 32, 33]]
- [[41, 42, 43], [51, 52, 53]]
- [[61, 62, 63], [71, 72, 73]]
如上介紹,第一維切取的起始位置是3,切取1的大小。所以只得到第一維的第三個元素,即[[41, 42, 43], [51, 52, 53]] 。
接下來我們繼續去掉最外層的中括號,以逗號分隔得到data列表第一維切取後的第二維如下:
- [41, 42, 43]
- [51, 52, 53]
如上介紹,第二維切取的起始位置是2,切取1的大小。所以得到第二維的第2個元素,即 [51, 52, 53] 。
第三維從第1個位置開始切,切取3個元素,得到 51, 52, 53。
4 進階
如何在不知道列表各個維度大小的情況下進行切取?
舉例如下:
import tensorflow as tf
data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [0, 1, 0], [-1, -1, 2])
with tf.Session() as sess:
data, slice_ = sess.run([data, slice_])
print('原始大小:\n', data.shape)
print('原始數據:\n',data)
print('切取後大小:\n', slice_.shape)
print('切取後數據:\n',slice_)
輸出爲:
原始大小:
(4, 2, 3)
原始數據:
[[[ 1 2 3]
[11 12 13]]
[[21 22 23]
[31 32 33]]
[[41 42 43]
[51 52 53]]
[[61 62 63]
[71 72 73]]]
切取後大小:
(4, 1, 2)
切取後數據:
[[[11 12]]
[[31 32]]
[[51 52]]
[[71 72]]]
-1可指代最後的一位。
具體請讀者自行揣摩。