在遇到形如tf.argmax(logits, axis=-1)
的代碼時,axis參數的含義非常容易令人疑惑。在二維情形下,axis=0
表示求每列的最大值的下標,axis=1
表示求每行最大值的下標。但是在更高維度下呢?
我們不妨假設數組A滿足A.shape=(2,4,8,16)
,研究A生成的數組(Ax=argmax(A,axis=x)
)的shape,結果如下表所示:
A0.shape | A1.shape | A2.shape | A3.shape |
---|---|---|---|
(4,8,16) | (2,8,16) | (2,4,16) | (2,4,8) |
不難發現,Ax.shape比原來的A.shape恰好少掉了第x個值。事實上,axis=x就表示將原數組的第x維壓縮,最終得到的第x維上各值的argmax。
特別地,我們知道在python中對一個長度的n的數組A,我們有A[n-i]=a[-i],同理,我們常見的axis=-1事實上就是表示壓縮最後一個維度,在常見的二維數組中,就表示對壓縮列,即對每行求和。
在其他一些降維的操作中,axis有着類似的作用。