numpy的axis

知乎上看到的超好的講解文章,摘錄過來了。原作者是射命丸咲,原文地址是:Python · numpy · axis

摘錄如下(些微做了一丁點刪減):

要想學習 axis,首先要知道的就是 axis 的計數方式。我們在使用 numpy 的各種函數——比如說 np.sum——的時候,有一個參數就叫做 axis。那麼這個參數的意思是什麼呢?最直白地來說的話,就是“最外面的括號代表着 axis=0,依次往裏的括號對應的 axis 的計數就依次加 1

舉個例子,現在我們有一個二維矩陣在 numpy 裏面是這樣被表達出來的:x = [ [0, 1], [2, 3] ],然後 axis 的對應方式就是:

所以相應的運算就是:

對應的代碼實現和運行結果如下:

可以看到,貌似出來的結果比我們推導的結果的括號要少一些。這是因爲諸如 np.sum 這種函數中有一個參數叫 keepdims,它的默認值是 False,此時它會把多餘的括號給刪掉。假如我們把它設爲 True 的話,就可以得到和推導中一致的結果了:

下面來看一個更“高維”一點的例子:

對應的代碼實現和運行結果如下:

以及

可以看到結果和我們推導的確實一樣

現在我們知道哪個 axis 對應於數組中的哪些元素了,接下來還需要知道的就是 transpose 這個函數到底在背後幹了什麼。從紙面上來看,如果一個高維數組 x 的 shape 是 (2, 3, 4),那麼 transpose 的作用就是把這個 shape 中各個數的順序改一改。比如說:

但是 transpose 返回的結果究竟是如何得到的?

首先是對這個 shape 的理解。直觀地說,shape 中的各個數就是對應 axis 的元素個數。比如說上圖中的 x,它畫出來會是這個樣子的:

如果我們換一種思路的話,以 axis=0 爲例,由於我們現在整個數組裏面一共有 24 個數,而 axis=0 只有兩個元素,所以可以理解爲在 axis=0 這個 axis 上,每隔 24 / 2 = 12 個數就跳一下。比如說上面這個圖中就可以看出,兩個橙色矩陣對應的數之間差的都是 12

類似的,由於一個橙色矩陣中只有 24 / 2 = 12 個數,所以我們可以理解爲在 axis=1 這個 axis 上,每隔 12 / 3 = 4 個數就跳一下。表現在圖中,就是同一個橙色矩陣的兩個相鄰的藍色向量對應的數之間差的都是 4

再次類似的,由於一個藍色向量中只有 12 / 3 = 4 個數,我們可以理解爲在 axis=2 這個 axis 上,每隔 4 / 4 = 1 個數就跳一下。表現在圖中……想必也知道是怎樣的了......

所以我們現在可以定義一個新的東西,比如說叫做 strides 吧,它記錄着每個 axis 上跳過的數。比如說上圖對應的三維數組,它的 strides 就是 (12, 4, 1)

那麼接下來激動人心的時刻到了:transpose 的本質,其實就是對 strides 中各個數的順序進行調換。舉個例子:

在 transpose(1, 0, 2) 後,相應的 strides 會變成 (4, 12, 1)。而從上圖可以看出,transpose 的結果確實滿足:

  • axis=0 的 axis 上,每隔 4 個數跳一下(對應的數字都差4,例如x[0,0,0]和x[1,0,0]差了4)
  • axis=1 的 axis 上,每隔 12 個數跳一下(對應的數字都差12,例如x[0,0,0]和x[0,1,0]差了12)
  • axis=2 的 axis 上,每隔 1 個數跳一下(對應的數字都差12,例如x[0,0,0]和x[0,0,1]差了1)

至此,transpose 背後的邏輯就理順啦!撒花!*★,°*:.☆\( ̄▽ ̄)/$:*.°★* 。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章