tf.squeeze()的解析

squeeze(
    input,
    axis=None,
    name=None,
    squeeze_dims=None
)

該函數返回一個張量,這個張量是將原始input中所有維度爲1的那些維都刪掉的結果
axis可以用來指定要刪掉的爲1的維度,此處要注意指定的維度必須確保其是1,否則會報錯。

>>>y = tf.squeeze(inputs, [0, 1], name='squeeze')
>>>ValueError: Can not squeeze dim[0], expected a dimension of 1, got 32 for 'squeeze' (op: 'Squeeze') with input shapes: [32,1,1,3].

例子:

#  't' 是一個維度是[1, 2, 1, 3, 1, 1]的張量
tf.shape(tf.squeeze(t))   # [2, 3], 默認刪除所有爲1的維度

# 't' 是一個維度[1, 2, 1, 3, 1, 1]的張量
tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1],標號從零開始,只刪掉了2和4維的1

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章