Two TensorFlow methods for adding or removing tensor dimensions

This does not come up often, but it is embarrassing to need it and not know how to do it!

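To add a dimension, use tf.expand_dims(); to remove dimensions of size 1, use tf.squeeze(). The source code of both functions is reproduced below. Reading through the docstrings is enough to understand how they work.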

def expand_dims(input, axis=None, name=None, dim=None):
  """Inserts a dimension of 1 into a tensor's shape.

  Given a tensor `input`, this operation inserts a dimension of 1 at the
  dimension index `axis` of `input`'s shape. The dimension index `axis` starts
  at zero; if you specify a negative number for `axis` it is counted backward
  from the end.

  This operation is useful if you want to add a batch dimension to a single
  element. For example, if you have a single image of shape `[height, width,
  channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
  which will make the shape `[1, height, width, channels]`.

  Other examples:
  # These are the concrete examples given: 0 inserts a new dimension at index 0, and 1 inserts a new dimension at index 1.

  ```python
  # 't' is a tensor of shape [2]
  tf.shape(tf.expand_dims(t, 0))  # [1, 2]
  tf.shape(tf.expand_dims(t, 1))  # [2, 1]
  tf.shape(tf.expand_dims(t, -1))  # [2, 1]

  # 't2' is a tensor of shape [2, 3, 5]
  tf.shape(tf.expand_dims(t2, 0))  # [1, 2, 3, 5]
  tf.shape(tf.expand_dims(t2, 2))  # [2, 3, 1, 5]
  tf.shape(tf.expand_dims(t2, 3))  # [2, 3, 5, 1]
  ```

  This operation requires that:

  `-1-input.dims() <= dim <= input.dims()`

  This operation is related to `squeeze()`, which removes dimensions of
  size 1.

  Args:
    input: A `Tensor`.
    axis: 0-D (scalar). Specifies the dimension index at which to
      expand the shape of `input`. Must be in the range
      `[-rank(input) - 1, rank(input)]`.
    name: The name of the output `Tensor`.
    dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.

  Returns:
    A `Tensor` with the same data as `input`, but its shape has an additional
    dimension of size 1 added.

  Raises:
    ValueError: if both `dim` and `axis` are specified.
  """
  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
  return gen_array_ops.expand_dims(input, axis, name)
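
The shape rule from the docstring above can be sketched in plain Python. This is only an illustration of how the output shape is derived, not TensorFlow's implementation: `expand_dims_shape` is a hypothetical helper that works on shape lists rather than on tensors.

```python
def expand_dims_shape(shape, axis):
    """Return the shape obtained by inserting a size-1 dimension at `axis`.

    Hypothetical helper for illustration; it models shapes only, not tensor data.
    """
    rank = len(shape)
    # Valid range stated in the docstring: [-rank - 1, rank].
    if not -rank - 1 <= axis <= rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    # A negative axis counts backward from the end of the output shape,
    # which has rank + 1 dimensions.
    if axis < 0:
        axis += rank + 1
    return list(shape[:axis]) + [1] + list(shape[axis:])


print(expand_dims_shape([2], 0))        # [1, 2]
print(expand_dims_shape([2], -1))       # [2, 1]
print(expand_dims_shape([2, 3, 5], 2))  # [2, 3, 1, 5]
```

With TensorFlow installed, the corresponding call on a real tensor `t` is `tf.expand_dims(t, axis)`.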

 

The code for the other function:

def squeeze(input, axis=None, name=None, squeeze_dims=None):
  # pylint: disable=redefined-builtin
  """Removes dimensions of size 1 from the shape of a tensor.

  Given a tensor `input`, this operation returns a tensor of the same type with
  all dimensions of size 1 removed. If you don't want to remove all size 1
  dimensions, you can remove specific size 1 dimensions by specifying
  `axis`.

  For example:

  ```python
  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
  tf.shape(tf.squeeze(t))  # [2, 3]
  ```

  Or, to remove specific size 1 dimensions:

  ```python
  # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
  tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]
  ```

  Args:
    input: A `Tensor`. The `input` to squeeze.
    axis: An optional list of `ints`. Defaults to `[]`.
      If specified, only squeezes the dimensions listed. The dimension
      index starts at 0. It is an error to squeeze a dimension that is not 1.
      Must be in the range `[-rank(input), rank(input))`.
    name: A name for the operation (optional).
    squeeze_dims: Deprecated keyword argument that is now axis.

  Returns:
    A `Tensor`. Has the same type as `input`.
    Contains the same data as `input`, but has one or more dimensions of
    size 1 removed.

  Raises:
    ValueError: When both `squeeze_dims` and `axis` are specified.
  """
  axis = deprecation.deprecated_argument_lookup(
      "axis", axis, "squeeze_dims", squeeze_dims)
  if np.isscalar(axis):
    axis = [axis]
  return gen_array_ops.squeeze(input, axis, name)
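
The squeeze rule can be sketched the same way. Again, this is a minimal illustration under the assumption that only shapes matter: `squeeze_shape` is a hypothetical helper, not part of TensorFlow. Like the wrapper above, it accepts a scalar axis by wrapping it in a list.

```python
def squeeze_shape(shape, axis=None):
    """Return the shape obtained by removing size-1 dimensions.

    Hypothetical helper for illustration; it models shapes only, not tensor data.
    """
    rank = len(shape)
    # A scalar axis is treated as a one-element list.
    if isinstance(axis, int):
        axis = [axis]
    # No axis given (None or an empty list): drop every size-1 dimension.
    if not axis:
        return [d for d in shape if d != 1]
    to_remove = set()
    for a in axis:
        # Valid range stated in the docstring: [-rank, rank).
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} is out of range for rank {rank}")
        a %= rank  # normalize a negative index
        if shape[a] != 1:
            raise ValueError(f"cannot squeeze axis {a} of size {shape[a]}")
        to_remove.add(a)
    return [d for i, d in enumerate(shape) if i not in to_remove]


print(squeeze_shape([1, 2, 1, 3, 1, 1]))          # [2, 3]
print(squeeze_shape([1, 2, 1, 3, 1, 1], [2, 4]))  # [1, 2, 3, 1]
```

With TensorFlow installed, the corresponding calls on a real tensor `t` are `tf.squeeze(t)` and `tf.squeeze(t, axis)`.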

 

 


 

 