tf.slice的理解

1 源代码注释的解释

  • This operation extracts a slice of size size from a tensor input starting at the location specified by begin. The slice size is represented as tensor shape, where size[i] is the number of elements of the 'i’th dimension of input that you want to slice. The starting location (begin) for the slice is represented as an offset in each dimension of input. In other words, begin[i] is the offset into the 'i’th dimension of input that you want to slice from.

2 tf.slice函数的参数

slice(input, begin, size, name)
input:输入的tensor变量。
begin:每个维度的起始位置。
size:从每个维度切取的大小。

3 实例解释

测试代码如下:

import tensorflow as tf

data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [2, 1, 0], [1, 1, 3])
with tf.Session() as sess:
    data, slice_ = sess.run([data, slice_])
    print('原始大小:\n', data.shape)
    print('原始数据:\n',data)
    print('切取后大小:\n', slice_.shape)
    print('切取后数据:\n',slice_)

结果输出如下:

原始大小:
 (4, 2, 3)
原始数据:
 [[[ 1  2  3]
  [11 12 13]]

 [[21 22 23]
  [31 32 33]]

 [[41 42 43]
  [51 52 53]]

 [[61 62 63]
  [71 72 73]]]
切取后大小:
 (1, 1, 3)
切取后数据:
 [[[51 52 53]]]

可以看到,原始的列表 data 是三维的,即[4, 2, 3],最后切取到的大小也是三维的[1 ,1, 3],切取到的是 [[[51 52 53]]]。

每个维度切取的起始位置分别是[2, 1, 0],并分别在每个维度切取[1,1,3]的大小。(注意python列表的下标是从0开始的)

我们先把最外层的中括号去掉,并以逗号分隔,得到原 data 列表第一维。

第一维是三个元素,分别是:

  • [[1, 2, 3], [11, 12, 13]]
  • [[21, 22, 23], [31, 32, 33]]
  • [[41, 42, 43], [51, 52, 53]]
  • [[61, 62, 63], [71, 72, 73]]

如上介绍,第一维切取的起始位置是3,切取1的大小。所以只得到第一维的第三个元素,即[[41, 42, 43], [51, 52, 53]] 。

接下来我们继续去掉最外层的中括号,以逗号分隔得到data列表第一维切取后的第二维如下:

  • [41, 42, 43]
  • [51, 52, 53]

如上介绍,第二维切取的起始位置是2,切取1的大小。所以得到第二维的第2个元素,即 [51, 52, 53] 。

第三维从第1个位置开始切,切取3个元素,得到 51, 52, 53。

4 进阶

如何在不知道列表各个维度大小的情况下进行切取?

举例如下:

import tensorflow as tf

data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [0, 1, 0], [-1, -1, 2])
with tf.Session() as sess:
    data, slice_ = sess.run([data, slice_])
    print('原始大小:\n', data.shape)
    print('原始数据:\n',data)
    print('切取后大小:\n', slice_.shape)
    print('切取后数据:\n',slice_)

输出为:

原始大小:
 (4, 2, 3)
原始数据:
 [[[ 1  2  3]
  [11 12 13]]

 [[21 22 23]
  [31 32 33]]

 [[41 42 43]
  [51 52 53]]

 [[61 62 63]
  [71 72 73]]]
切取后大小:
 (4, 1, 2)
切取后数据:
 [[[11 12]]

 [[31 32]]

 [[51 52]]

 [[71 72]]]

-1可指代最后的一位。

具体请读者自行揣摩。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章