理解numpy中的axis(軸)

在numpy中axis是一個比較難理解的點,在很長一段時間我都是在處理一些2維的數組,所以往往對這塊知識有所忽略,直到我在做斯坦福的cs231n的assignment時候,纔對axis有了更加深入的理解。

先來看一組簡單的代碼。

import numpy as np

t = np.arange(8).reshape(2, 4)
print("origin: ", t, " shape: ", t.shape)
print("sum: ", t.sum(0), " shape: ", t.sum(0).shape)
print("sum: ", t.sum(1), " shape: ", t.sum(1).shape)

結果如下:
origin: [[0 1 2 3]
[4 5 6 7]] shape: (2, 4)
sum: [ 4 6 8 10] shape: (4,)
sum: [ 6 22] shape: (2,)

首先確定一點,axis=0是shape中從左往右數的第0個軸也就是(“2”,4)加引號部分,以此類推
我們要求和的axis就是將該軸消去 ,這裏的t.sum(0)即消去0軸故只剩下 (4,),sum(1)同理

現在從結果來看,以sum(1)爲例 寫出下標 6: (0,x) 22 :(1,x) 這裏寫出x表示我們消去的軸
現在我們可以知道其實就是執行一個循環將所有0軸相同的元素加起來
6:(0,x) = 0(0,0)+1(0,1) +2(0,2)+3(0,3)
22:(1,x)=4(1,0)+5(1,1) +6(1,2)+7(1,3)
而 t.sum((0,1))=28(x,x)則是把兩個維度的都加起來,簡單不贅述

是不是還挺好理解的下面我們看下一組三維的數組

import numpy as np

t = np.arange(8).reshape(2, 2, 2)
print("origin: ", t, " shape: ", t.shape)
print("sum: ", t.sum(1), " shape: ", t.sum(0).shape)
print("sum: ", t.sum((1, 2)), " shape: ", t.sum(1).shape)

origin: [[[0 1]
[2 3]]

[[4 5]
[6 7]]] shape: (2, 2, 2)
sum: [[ 2 4]
[10 12]] shape: (2, 2)
sum: [ 6 22] shape: (2, 2)

這裏我將下標一一標出
t :
0: (0,0,0)
1: (0,0,1)
2: (0,1,0)
3: (0,1,1)
4: (1,0,0)
5: (1,0,1)
6: (1,1,0)
7: (1,1,1)

t.sum(1):
2:(0,x,0)
4:(0,x,1)
10:(1,x,0)
12:(1,x,1)

相信看到這裏已經很明瞭了,和2維的情況類似,x是被消去的軸,我們只要找到第0軸和第2軸相同元素相加即可

t.sum((1,2))
6:(0,x,x)
22:(1,x,x)

第0軸是其保留下來的軸,我們找到第0軸相同的元素全部相加,即可求得結果。
這裏只寫sum函數,其實mean,std,max等等函數都是一樣的操作。

如果這篇文章對你有幫助,順手點個贊,我會很開心!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章