Numpy的argpartion函數
一、np.argpartion()
Numpy的argpartion函數來源於快排算法中的一個典型操作partition,即根據一個數值x,把數組中的元素劃分成兩半,使得index前面的元素都不大於x,index後面的元素都不小於x。
np.argpartition不對原數組操作,只返回一個按照上述操作排序過後的index。通過這個函數可以高效地找到 N 個最大值的索引並返回 N 個值。在給出索引後,我們可以根據需要進行值排序。
import numpy as np
array = np.array([10, 7, 4, 3, 2, 2, 5, 9, 0, 4, 6, 0])
#返回一個索引,比原數組第5大(從0開始)的數小的數在這個數之前,比這個數大的數在它之後。
index = np.argpartition(array, 4)
#輸出,新索引
print(index)
#[ 4 11 8 5 3 2 9 6 1 10 7 0]
#按這個新索引可以重新排列數組
print(array[index])
#[ 2 0 0 2 3 4 4 5 7 6 9 10]
#第5大的數是3,比3小的在3之前,比3大的在3之後
二、輸出top5
np.argpartition的一個重要應用就是高效輸出最大的幾個值(如top-5),因爲不用像np.sort對所有元素排序。
#還是上邊那個數組,輸出top5
array = np.array([10, 7, 4, 3, 2, 2, 5, 9, 0, 4, 6, 0])
array[np.argpartition(array, -5)[-5:]]
#輸出:[ 5, 7, 6, 9, 10]
先排序,比倒數第5個數大的,都排在這個數後邊,返回索引,然後用[-5:]把最後5個(最大的5個的索引)取出來,然後再array[index]就把最大的5個數(top5)取出來了。此法非常高效,因爲不用對原數組操作,只返回索引。
參考:
1.numpy中的argpartition
2.numpy中的argpartition用法
3.Argpartition:在數組中找到最大的 N 個元素。