tf.squeeze()函數

tf.squeeze()函數用於從張量形狀中移除大小爲1的維度

 

squeeze(
    input,
    axis=None,
    name=None,
    squeeze_dims=None
)

給定張量輸入,此操作返回相同類型的張量,並刪除所有維度爲1的維度。 如果不想刪除所有維度1維度,可以通過指定squeeze_dims來刪除特定維度1維度。

如果不想刪除所有大小是1的維度,可以通過squeeze_dims指定。

參數:

    input:A Tensor。輸入要擠壓。

    axis:一個可選列表ints。默認爲[]。如果指定,只能擠壓列出的尺寸。維度索引從0開始。壓縮非1的維度是錯誤的。必須在範圍內[-- rank(input), rank(input))。

    name:操作的名稱(可選)。

    squeeze_dims:現在是軸的已棄用的關鍵字參數。

函數返回值:

一Tensor。與輸入類型相同。 包含與輸入相同的數據,但具有一個或多個刪除尺寸1的維度。

可能引發的異常:

    ValueError:當兩個squeeze_dims和axis指定。

例子1:

該函數返回一個張量,這個張量是將原始input中所有維度爲1的那些維都刪掉的結果。

axis可以用來指定要刪掉的爲1的維度,此處要注意指定的維度必須確保其是1,否則會報錯。

import tensorflow as tf
import numpy as np

value = np.floor(10*np.random.random((3,2,2)))
with tf.Session() as sess:
    tf.squeeze(value)
    print(sess.run(tf.shape(tf.squeeze(value, [1]))))
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_17' (op: 'Squeeze') with input shapes: [3,2,2].

例子2:

默認刪除所有維度是1的維度。

import tensorflow as tf
import numpy as np

value = np.floor(10*np.random.random((1,3,2,1,2)))
with tf.Session() as sess:
    print(sess.run(tf.shape(tf.squeeze(value))))
[3 2 2]

例子3:

如果不想刪除所有尺寸1尺寸,可以通過指定axis來刪除特定維度1的維度。

import tensorflow as tf
import numpy as np

value = np.floor(10*np.random.random((1,3,2,1,2)))
with tf.Session() as sess:
    print(sess.run(tf.shape(tf.squeeze(value, [0]))))
[3 2 1 2]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章