在使用numpy多維數組時我們常會需要獲取數組中的元素,這一般有兩種方法:
import numpy as np
a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[2, 2])
print(a[2][2])
'''
[[17 12 16 14 19 10 14 13 15 13 17 19 19 11 11 18 16 12 16 17]
[15 17 15 11 19 16 19 18 12 12 19 19 15 19 18 11 18 12 10 13]
[10 12 13 12 14 11 12 12 10 18 14 16 16 16 14 13 11 15 11 15]
[10 17 19 11 16 17 15 11 14 12 17 14 15 17 12 17 16 15 10 14]
[13 14 13 16 14 18 14 16 16 10 11 13 14 14 12 11 18 12 14 13]
[17 19 14 15 19 12 10 17 14 13 19 11 17 13 17 10 19 14 18 11]
[17 11 13 18 14 17 14 18 11 18 18 14 16 19 18 18 15 18 15 12]
[19 17 10 13 14 12 19 16 10 18 11 11 12 18 16 15 15 13 15 19]
[18 12 11 15 11 13 13 18 15 19 19 13 16 13 19 15 10 12 10 15]
[14 19 12 10 11 10 14 19 12 10 19 12 18 15 18 17 19 12 18 14]]
13
13
'''
但當我們需要用到切片時,第二種寫法卻是錯誤的:
import numpy as np
a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[:2, :2])
print(a[:2][:2])
'''
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
[16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]
[12 10 10 12 13 18 10 13 14 13 13 19 10 16 13 19 13 19 13 17]
[17 19 11 10 11 17 16 10 18 15 10 18 12 15 17 11 16 13 12 11]
[17 19 11 13 12 15 14 16 12 12 14 11 15 13 19 19 17 14 16 19]
[10 15 18 19 10 15 12 13 11 18 19 11 14 15 14 17 15 10 13 16]
[10 16 17 18 19 14 15 10 14 11 11 16 18 15 16 12 10 11 16 18]
[13 12 13 13 16 16 17 16 15 15 15 16 14 16 15 16 19 14 19 14]
[16 13 17 16 10 15 19 15 19 13 19 12 16 11 14 17 18 19 15 15]
[14 10 19 14 10 11 16 14 10 16 18 12 10 14 12 10 12 14 10 15]]
[[16 16]
[16 13]]
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
[16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]]
'''
第一種寫法獲取到了我們實際想要的子矩陣,而第二種寫法實際上需要分開來看待:先獲取a的前兩行得到一個子矩陣,再獲取這個子矩陣的前兩行。
最近寫代碼時總弄混這兩個寫法,因此記錄一下,numpy切片的正確用法是用逗號隔開,而不是像多維數組索引那樣隔開。
今天又發現了新的問題,numpy真是有趣。在做cs231n的作業時,我需要從一個的分數中按照一個的label來取出的正確分數(每行按照label選一個分數),自然會想到花式索引和切片結合的方法,但遇到了一些問題,這裏總結一下可能的寫法:
a = np.random.randint(5, 10, size=(5, 10))
print(a)
y1 = np.random.randint(0, 10, size=(5, ))
print(y1)
y2 = np.random.randint(0, 10, size=(5, 1))
print(y2)
'''
[[7 6 5 7 9 5 9 6 8 8]
[8 8 8 7 9 9 8 5 9 8]
[8 9 8 7 5 7 7 6 8 8]
[6 9 5 9 5 7 7 8 7 7]
[9 6 6 9 9 6 7 9 7 6]]
[8 1 0 9 8]
[[5]
[3]
[5]
[9]
[5]]
'''
可以看到y1和y2的shape是不一樣的,y1是一個數組,y2則是一個二維矩陣。
print(a[:, y1])
'''
[[8 6 7 8 8]
[9 8 8 8 9]
[8 9 8 8 8]
[7 9 6 7 7]
[7 6 9 6 7]]
可以看到,這種寫法得到一個N * N的矩陣,每一行對應a的每一行按照y1的所有元素來取值,
即本來a的每一行取一個值就可以,但是卻取了N個值,每一行相當於a[i, y1]
a[0, y1] = [8 6 7 8 8]
'''
print(a[range(5), y1])
'''
[8 8 8 7 7]
這種寫法就是我們想要的結果
'''
print(a[:, y2])
'''
[[[5]
[7]
[5]
[8]
[5]]
[[9]
[7]
[9]
[8]
[9]]
[[7]
[7]
[7]
[8]
[7]]
[[7]
[9]
[7]
[7]
[7]]
[[6]
[9]
[6]
[6]
[6]]]
這種寫法得到的結果更加離譜,是一個5 * 5 * 1的三維矩陣,
每個5 * 1的子矩陣相當於第一種寫法的結果
'''
print(a[range(5), y2])
'''
[[5 9 7 7 6]
[7 7 7 9 9]
[5 9 7 7 6]
[8 8 8 7 6]
[5 9 7 7 6]]
一共5行,每一行都是a的某一列,按照y2來取值。
'''
結論就是要使用range
和一維數組來進行切片和花式索引。