tf.slice的理解

1 源代碼註釋的解釋

  • This operation extracts a slice of size size from a tensor input starting at the location specified by begin. The slice size is represented as tensor shape, where size[i] is the number of elements of the 'i’th dimension of input that you want to slice. The starting location (begin) for the slice is represented as an offset in each dimension of input. In other words, begin[i] is the offset into the 'i’th dimension of input that you want to slice from.

2 tf.slice函數的參數

slice(input, begin, size, name)
input:輸入的tensor變量。
begin:每個維度的起始位置。
size:從每個維度切取的大小。

3 實例解釋

測試代碼如下:

import tensorflow as tf

data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [2, 1, 0], [1, 1, 3])
with tf.Session() as sess:
    data, slice_ = sess.run([data, slice_])
    print('原始大小:\n', data.shape)
    print('原始數據:\n',data)
    print('切取後大小:\n', slice_.shape)
    print('切取後數據:\n',slice_)

結果輸出如下:

原始大小:
 (4, 2, 3)
原始數據:
 [[[ 1  2  3]
  [11 12 13]]

 [[21 22 23]
  [31 32 33]]

 [[41 42 43]
  [51 52 53]]

 [[61 62 63]
  [71 72 73]]]
切取後大小:
 (1, 1, 3)
切取後數據:
 [[[51 52 53]]]

可以看到,原始的列表 data 是三維的,即[4, 2, 3],最後切取到的大小也是三維的[1 ,1, 3],切取到的是 [[[51 52 53]]]。

每個維度切取的起始位置分別是[2, 1, 0],並分別在每個維度切取[1,1,3]的大小。(注意python列表的下標是從0開始的)

我們先把最外層的中括號去掉,並以逗號分隔,得到原 data 列表第一維。

第一維是三個元素,分別是:

  • [[1, 2, 3], [11, 12, 13]]
  • [[21, 22, 23], [31, 32, 33]]
  • [[41, 42, 43], [51, 52, 53]]
  • [[61, 62, 63], [71, 72, 73]]

如上介紹,第一維切取的起始位置是3,切取1的大小。所以只得到第一維的第三個元素,即[[41, 42, 43], [51, 52, 53]] 。

接下來我們繼續去掉最外層的中括號,以逗號分隔得到data列表第一維切取後的第二維如下:

  • [41, 42, 43]
  • [51, 52, 53]

如上介紹,第二維切取的起始位置是2,切取1的大小。所以得到第二維的第2個元素,即 [51, 52, 53] 。

第三維從第1個位置開始切,切取3個元素,得到 51, 52, 53。

4 進階

如何在不知道列表各個維度大小的情況下進行切取?

舉例如下:

import tensorflow as tf

data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [0, 1, 0], [-1, -1, 2])
with tf.Session() as sess:
    data, slice_ = sess.run([data, slice_])
    print('原始大小:\n', data.shape)
    print('原始數據:\n',data)
    print('切取後大小:\n', slice_.shape)
    print('切取後數據:\n',slice_)

輸出爲:

原始大小:
 (4, 2, 3)
原始數據:
 [[[ 1  2  3]
  [11 12 13]]

 [[21 22 23]
  [31 32 33]]

 [[41 42 43]
  [51 52 53]]

 [[61 62 63]
  [71 72 73]]]
切取後大小:
 (4, 1, 2)
切取後數據:
 [[[11 12]]

 [[31 32]]

 [[51 52]]

 [[71 72]]]

-1可指代最後的一位。

具體請讀者自行揣摩。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章