tf.squeeze()函數用於從張量形狀中移除大小爲1的維度
squeeze(
input,
axis=None,
name=None,
squeeze_dims=None
)
給定張量輸入,此操作返回相同類型的張量,並刪除所有維度爲1的維度。 如果不想刪除所有維度1維度,可以通過指定squeeze_dims來刪除特定維度1維度。
如果不想刪除所有大小是1的維度,可以通過squeeze_dims指定。
參數:
input:A Tensor。輸入要擠壓。
axis:一個可選列表ints。默認爲[]。如果指定,只能擠壓列出的尺寸。維度索引從0開始。壓縮非1的維度是錯誤的。必須在範圍內[-- rank(input), rank(input))。
name:操作的名稱(可選)。
squeeze_dims:現在是軸的已棄用的關鍵字參數。
函數返回值:
一Tensor。與輸入類型相同。 包含與輸入相同的數據,但具有一個或多個刪除尺寸1的維度。
可能引發的異常:
ValueError:當兩個squeeze_dims和axis指定。
例子1:
該函數返回一個張量,這個張量是將原始input中所有維度爲1的那些維都刪掉的結果。
axis可以用來指定要刪掉的爲1的維度,此處要注意指定的維度必須確保其是1,否則會報錯。
import tensorflow as tf
import numpy as np
value = np.floor(10*np.random.random((3,2,2)))
with tf.Session() as sess:
tf.squeeze(value)
print(sess.run(tf.shape(tf.squeeze(value, [1]))))
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_17' (op: 'Squeeze') with input shapes: [3,2,2].
例子2:
默認刪除所有維度是1的維度。
import tensorflow as tf
import numpy as np
value = np.floor(10*np.random.random((1,3,2,1,2)))
with tf.Session() as sess:
print(sess.run(tf.shape(tf.squeeze(value))))
[3 2 2]
例子3:
如果不想刪除所有尺寸1尺寸,可以通過指定axis來刪除特定維度1的維度。
import tensorflow as tf
import numpy as np
value = np.floor(10*np.random.random((1,3,2,1,2)))
with tf.Session() as sess:
print(sess.run(tf.shape(tf.squeeze(value, [0]))))
[3 2 1 2]