Numpy的argpartion函数

Numpy的argpartion函数

一、np.argpartion()

Numpy的argpartion函数来源于快排算法中的一个典型操作partition,即根据一个数值x,把数组中的元素划分成两半,使得index前面的元素都不大于x,index后面的元素都不小于x。
np.argpartition不对原数组操作,只返回一个按照上述操作排序过后的index。通过这个函数可以高效地找到 N 个最大值的索引并返回 N 个值。在给出索引后,我们可以根据需要进行值排序。

import numpy as np
array = np.array([10, 7, 4, 3, 2, 2, 5, 9, 0, 4, 6, 0])
#返回一个索引,比原数组第5大(从0开始)的数小的数在这个数之前,比这个数大的数在它之后。
index = np.argpartition(array, 4)
#输出,新索引
print(index)
#[ 4 11  8  5  3  2  9  6  1 10  7  0]
#按这个新索引可以重新排列数组
print(array[index])
#[ 2  0  0  2  3  4  4  5  7  6  9 10]
#第5大的数是3,比3小的在3之前,比3大的在3之后

二、输出top5

np.argpartition的一个重要应用就是高效输出最大的几个值(如top-5),因为不用像np.sort对所有元素排序。

#还是上边那个数组,输出top5
array = np.array([10, 7, 4, 3, 2, 2, 5, 9, 0, 4, 6, 0])
array[np.argpartition(array, -5)[-5:]]
#输出:[ 5,  7,  6,  9, 10]

先排序,比倒数第5个数大的,都排在这个数后边,返回索引,然后用[-5:]把最后5个(最大的5个的索引)取出来,然后再array[index]就把最大的5个数(top5)取出来了。此法非常高效,因为不用对原数组操作,只返回索引。

参考:
1.numpy中的argpartition
2.numpy中的argpartition用法
3.Argpartition:在数组中找到最大的 N 个元素。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章