本文轉自知乎周康生的文章Numpy:對Axis的理解,轉載無刪改。
- Axis就是數組層級
- 設axis=i,則Numpy沿着第i個下標變化的方向進行操作
- Axis的應用
Axis就是數組層級
要想理解axis,首先我們先要弄清楚“Numpy中數組的維數”和"線性代數中矩陣的維數"這兩個概念以及它們之間的關係。在數學或者物理的概念中,dimensions被認爲是在空間中表示一個點所需要的最少座標個數,但是在Numpy中,dimensions指代的是數組的維數。比如下面這個例子:
>>> import numpy as np
>>> a = np.array([[1,2,3],[2,3,4],[3,4,9]])
>>> a
array([[1, 2, 3],
[2, 3, 4],
[3, 4, 9]])
這個array的維數只有2,即axis軸有兩個,分別是axis=0和axis=1。如下圖所示,該二維數組的第0維(axis=0)有三個元素(左圖),即axis=0軸的長度length爲3;第1維(axis=1)也有三個元素(右圖),即axis=1軸的長度length爲3。正是因爲axis=0、axis=1的長度都爲3,矩陣橫着豎着都有3個數,所以該矩陣在線性代數是3維的(rank秩爲3)。
因此,axis就是數組層級。
當axis=0,該軸上的元素有3個(數組的size爲3)
a[0]
、a[1]
、a[2]
當axis=1,該軸上的元素有3個(數組的size爲3)
a[0][0]
、a[0][1]
、a[0][2]
(或者a[1][0]
、a[1][1]
、a[1][2]
)
(或者a[2][0]
、a[2][1]
、a[2][2]
)
再比如下面shape爲(3,2,4)的array:
>>> b = np.array([[[1,2,3,4],[1,3,4,5]],[[2,4,7,5],[8,4,3,5]],[[2,5,7,3],[1,5,3,7]]])
>>> b
array([[[1, 2, 3, 4],
[1, 3, 4, 5]],
[[2, 4, 7, 5],
[8, 4, 3, 5]],
[[2, 5, 7, 3],
[1, 5, 3, 7]]])
>>> b.shape
(3, 2, 4)
這個shape(用tuple表示)可以理解爲在每個軸(axis)上的size,也即佔有的長度(length)。爲了更進一步理解,我們可以暫時把多個axes想象成多層layers。axis=0表示第一層(下圖黑色框框),該層數組的size爲3,對應軸上的元素length = 3;axis=1表示第二層(下圖紅色框框),該層數組的size爲2,對應軸上的元素length = 2;axis=2表示第三層(下圖藍色框框),對應軸上的元素length = 4。
設axis=i,則Numpy沿着第i個下標變化的方向進行操作
1.二維數組示例:
比如np.sum(a, axis=1)
,結合下面的數組, a[0][0]
=1、a[0][1]
=2、a[0][2]
=3 ,下標會發生變化的方向是數組的第一維。
我們往下標會變化的方向,把元素相加後即可得到最終結果:
[
[6],
[9],
[16]
]
2.三維數組示例:
再舉個例子,比如下邊這個np.shape(a)=(3,2,4)
的3維數組,該數組第0維的長度爲3(黑色框框),再深入一層,第1維的長度爲2(紅色框框),再深入一層,第2維的長度爲4(藍色框框)。
如果我們要計算np.sum(a, axis=1)
,在第一個黑色框框中,
下標的變化方向如下所示:
所以,我們要把上下兩個紅色框框相加起來
按照同樣的邏輯處理第二個和第三個黑色的框框,可以得出最終結果:
所以,依然是我們前邊總結的那一句話,設axis=i,則Numpy沿着第i個下標變化的方向進行操作。
3.四維數組示例:
比如下面這個巨複雜的4維數組,
>>> data = np.random.randint(0, 5, [4,3,2,3])
>>> data
array([[[[4, 1, 0],
[4, 3, 0]],
[[1, 2, 4],
[2, 2, 3]],
[[4, 3, 3],
[4, 2, 3]]],
[[[4, 0, 1],
[1, 1, 1]],
[[0, 1, 0],
[0, 4, 1]],
[[1, 3, 0],
[0, 3, 0]]],
[[[3, 3, 4],
[0, 1, 0]],
[[1, 2, 3],
[4, 0, 4]],
[[1, 4, 1],
[1, 3, 2]]],
[[[0, 1, 1],
[2, 4, 3]],
[[4, 1, 4],
[1, 4, 1]],
[[0, 1, 0],
[2, 4, 3]]]])
當axis=0時,numpy沿着第0維的方向進行求和,也就是第一個元素值=a0000+a1000+a2000+a3000=11,第二個元素=a0001+a1001+a2001+a3001=5,同理可得最後的結果如下:
>>> data.sum(axis=0)
array([[[11, 5, 6],
[ 7, 9, 4]],
[[ 6, 6, 11],
[ 7, 10, 9]],
[[ 6, 11, 4],
[ 7, 12, 8]]])
當axis=3時,numpy沿着第3維的方向進行求和,也就是第一個元素值=a0000+a0001+a0002=5,第二個元素=a0010+a0011+a0012=7,同理可得最後的結果如下:
>>> data.sum(axis=3)
array([[[ 5, 7],
[ 7, 7],
[10, 9]],
[[ 5, 3],
[ 1, 5],
[ 4, 3]],
[[10, 1],
[ 6, 8],
[ 6, 6]],
[[ 2, 9],
[ 9, 6],
[ 1, 9]]])
Axis的應用
例如現在我們收集了四個同學對蘋果、榴蓮、西瓜這三種水果的喜愛程度進行打分的數據(總分爲10),每個同學都有三個特徵:
>>> item = np.array([[1,4,8],[2,3,5],[2,5,1],[1,10,7]])
>>> item
array([[1, 4, 8],
[2, 3, 5],
[2, 5, 1],
[1, 10, 7]])
每一行包含了同一個人的三個特徵,如果我們想看看哪個同學最喜歡喫水果,那就可以用:
>>> item.sum(axis = 1)
array([13, 10, 8, 18])
可以大概看出來同學4最喜歡喫水果。
如果我們想看看哪種水果最受歡迎,那就可以用:
>>> item.sum(axis = 0)
array([ 6, 22, 21])
可以看出基本是榴蓮最受歡迎。