python axis參數解析

在遇到形如tf.argmax(logits, axis=-1)的代碼時,axis參數的含義非常容易令人疑惑。在二維情形下,axis=0表示求每列的最大值的下標,axis=1表示求每行最大值的下標。但是在更高維度下呢?

我們不妨假設數組A滿足A.shape=(2,4,8,16),研究A生成的數組(Ax=argmax(A,axis=x))的shape,結果如下表所示:

A0.shape A1.shape A2.shape A3.shape
(4,8,16) (2,8,16) (2,4,16) (2,4,8)

不難發現,Ax.shape比原來的A.shape恰好少掉了第x個值。事實上,axis=x就表示將原數組的第x維壓縮,最終得到的第x維上各值的argmax。

特別地,我們知道在python中對一個長度的n的數組A,我們有A[n-i]=a[-i],同理,我們常見的axis=-1事實上就是表示壓縮最後一個維度,在常見的二維數組中,就表示對壓縮列,即對每行求和。

在其他一些降維的操作中,axis有着類似的作用。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章