numpy數組索引遇到的問題

在使用numpy多維數組時我們常會需要獲取數組中的元素,這一般有兩種方法:

import numpy as np

a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[2, 2])
print(a[2][2])
'''
[[17 12 16 14 19 10 14 13 15 13 17 19 19 11 11 18 16 12 16 17]
 [15 17 15 11 19 16 19 18 12 12 19 19 15 19 18 11 18 12 10 13]
 [10 12 13 12 14 11 12 12 10 18 14 16 16 16 14 13 11 15 11 15]
 [10 17 19 11 16 17 15 11 14 12 17 14 15 17 12 17 16 15 10 14]
 [13 14 13 16 14 18 14 16 16 10 11 13 14 14 12 11 18 12 14 13]
 [17 19 14 15 19 12 10 17 14 13 19 11 17 13 17 10 19 14 18 11]
 [17 11 13 18 14 17 14 18 11 18 18 14 16 19 18 18 15 18 15 12]
 [19 17 10 13 14 12 19 16 10 18 11 11 12 18 16 15 15 13 15 19]
 [18 12 11 15 11 13 13 18 15 19 19 13 16 13 19 15 10 12 10 15]
 [14 19 12 10 11 10 14 19 12 10 19 12 18 15 18 17 19 12 18 14]]
13
13
 '''

但當我們需要用到切片時,第二種寫法卻是錯誤的:

import numpy as np

a = np.random.randint(10, 20, size=[10, 20])
print(a)
print(a[:2, :2])
print(a[:2][:2])
'''
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
 [16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]
 [12 10 10 12 13 18 10 13 14 13 13 19 10 16 13 19 13 19 13 17]
 [17 19 11 10 11 17 16 10 18 15 10 18 12 15 17 11 16 13 12 11]
 [17 19 11 13 12 15 14 16 12 12 14 11 15 13 19 19 17 14 16 19]
 [10 15 18 19 10 15 12 13 11 18 19 11 14 15 14 17 15 10 13 16]
 [10 16 17 18 19 14 15 10 14 11 11 16 18 15 16 12 10 11 16 18]
 [13 12 13 13 16 16 17 16 15 15 15 16 14 16 15 16 19 14 19 14]
 [16 13 17 16 10 15 19 15 19 13 19 12 16 11 14 17 18 19 15 15]
 [14 10 19 14 10 11 16 14 10 16 18 12 10 14 12 10 12 14 10 15]]
[[16 16]
 [16 13]]
[[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17]
 [16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]]
'''

第一種寫法獲取到了我們實際想要的子矩陣,而第二種寫法實際上需要分開來看待:先獲取a的前兩行得到一個子矩陣,再獲取這個子矩陣的前兩行。
最近寫代碼時總弄混這兩個寫法,因此記錄一下,numpy切片的正確用法是用逗號隔開,而不是像多維數組索引那樣隔開


今天又發現了新的問題,numpy真是有趣。在做cs231n的作業時,我需要從一個NCN * C的分數中按照一個N1N * 1的label來取出N1N * 1的正確分數(每行按照label選一個分數),自然會想到花式索引和切片結合的方法,但遇到了一些問題,這裏總結一下可能的寫法:

a = np.random.randint(5, 10, size=(5, 10))
print(a)

y1 = np.random.randint(0, 10, size=(5, ))
print(y1)

y2 = np.random.randint(0, 10, size=(5, 1))
print(y2)
'''
[[7 6 5 7 9 5 9 6 8 8]
 [8 8 8 7 9 9 8 5 9 8]
 [8 9 8 7 5 7 7 6 8 8]
 [6 9 5 9 5 7 7 8 7 7]
 [9 6 6 9 9 6 7 9 7 6]]
[8 1 0 9 8]
[[5]
 [3]
 [5]
 [9]
 [5]]
'''

可以看到y1和y2的shape是不一樣的,y1是一個數組,y2則是一個二維矩陣。

print(a[:, y1])
'''
[[8 6 7 8 8]
 [9 8 8 8 9]
 [8 9 8 8 8]
 [7 9 6 7 7]
 [7 6 9 6 7]]
 可以看到,這種寫法得到一個N * N的矩陣,每一行對應a的每一行按照y1的所有元素來取值,
 即本來a的每一行取一個值就可以,但是卻取了N個值,每一行相當於a[i, y1]
 a[0, y1] = [8 6 7 8 8]
'''
print(a[range(5), y1])
'''
[8 8 8 7 7]
這種寫法就是我們想要的結果
'''
print(a[:, y2])
'''
[[[5]
  [7]
  [5]
  [8]
  [5]]

 [[9]
  [7]
  [9]
  [8]
  [9]]

 [[7]
  [7]
  [7]
  [8]
  [7]]

 [[7]
  [9]
  [7]
  [7]
  [7]]

 [[6]
  [9]
  [6]
  [6]
  [6]]]
 這種寫法得到的結果更加離譜,是一個5 * 5 * 1的三維矩陣,
 每個5 * 1的子矩陣相當於第一種寫法的結果
'''
print(a[range(5), y2])
'''
[[5 9 7 7 6]
 [7 7 7 9 9]
 [5 9 7 7 6]
 [8 8 8 7 6]
 [5 9 7 7 6]]
 一共5行,每一行都是a的某一列,按照y2來取值。
'''

結論就是要使用range和一維數組來進行切片和花式索引。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章