Numpy的argpartion函數

Numpy的argpartion函數

一、np.argpartion()

Numpy的argpartion函數來源於快排算法中的一個典型操作partition,即根據一個數值x,把數組中的元素劃分成兩半,使得index前面的元素都不大於x,index後面的元素都不小於x。
np.argpartition不對原數組操作,只返回一個按照上述操作排序過後的index。通過這個函數可以高效地找到 N 個最大值的索引並返回 N 個值。在給出索引後,我們可以根據需要進行值排序。

import numpy as np
array = np.array([10, 7, 4, 3, 2, 2, 5, 9, 0, 4, 6, 0])
#返回一個索引,比原數組第5大(從0開始)的數小的數在這個數之前,比這個數大的數在它之後。
index = np.argpartition(array, 4)
#輸出,新索引
print(index)
#[ 4 11  8  5  3  2  9  6  1 10  7  0]
#按這個新索引可以重新排列數組
print(array[index])
#[ 2  0  0  2  3  4  4  5  7  6  9 10]
#第5大的數是3,比3小的在3之前,比3大的在3之後

二、輸出top5

np.argpartition的一個重要應用就是高效輸出最大的幾個值(如top-5),因爲不用像np.sort對所有元素排序。

#還是上邊那個數組,輸出top5
array = np.array([10, 7, 4, 3, 2, 2, 5, 9, 0, 4, 6, 0])
array[np.argpartition(array, -5)[-5:]]
#輸出:[ 5,  7,  6,  9, 10]

先排序,比倒數第5個數大的,都排在這個數後邊,返回索引,然後用[-5:]把最後5個(最大的5個的索引)取出來,然後再array[index]就把最大的5個數(top5)取出來了。此法非常高效,因爲不用對原數組操作,只返回索引。

參考:
1.numpy中的argpartition
2.numpy中的argpartition用法
3.Argpartition:在數組中找到最大的 N 個元素。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章