題目描述
輸入n個整數,找出其中最小的K個數。例如輸入4,5,1,6,2,7,3,8這8個數字,則最小的4個數字是1,2,3,4,。
題目鏈接:牛客網
解題思路
快速選擇
- 複雜度:O(N) + O(1)
- 只有當允許修改數組元素時纔可以使用
快速排序的 partition() 方法,會返回一個整數 j 使得 a[l…j-1] 小於等於 a[j],且 a[j+1…h] 大於等於 a[j],此時 a[j] 就是數組的第 j 大元素。可以利用這個特性找出數組的第 K 個元素,這種找第 K 個元素的算法稱爲快速選擇算法。
import java.util.*;
public class Main {
public static void main(String[] args) {
int[] nums = {4,5,1,6,2,7,3,8};
ArrayList list = getLeastNumbers_Solution(nums,4);
printList(list);
}
public static ArrayList<Integer> getLeastNumbers_Solution(int[] nums,int k) {
ArrayList<Integer> list = new ArrayList();
if (k > nums.length || k <= 0) {
return list;
}
findKthSmallest(nums,k - 1);
// findKthSmallest 會改變數組,使得前 k 個數都是最小的 k 個數
for (int i = 0;i < k;i++) {
list.add(nums[i]);
}
return list;
}
public static void findKthSmallest(int[] nums,int k) {
int l = 0;
int h = nums.length - 1;
while (l < h) {
int j = partition(nums, l, h);
if (j == k) {
break;
}
if (j > k) {
h = j - 1;
}else {
l = j + 1;
}
}
}
public static int partition(int[] nums,int l,int h) {
int p = nums[l]; // 切分元素
int i = l, j = h + 1;
while(true) {
while (i != h && nums[++i] < p);
while (j != l && nums[--j] > p);
if (i >= j) {
break;
}
swap(nums, i, j);
}
swap(nums,l,j);
return j;
}
public static void swap(int[] nums, int i, int j) {
int t = nums[i];
nums[i] = nums[j];
nums[j] = t;
}
public static void printList(ArrayList list) {
for (int i = 0;i < list.size();i++) {
System.out.print(list.get(i) + " ");
}
}
}
大小爲 K 的最小堆
- 複雜度:O(NlogK) + O(K)
- 特別適合處理海量數據
應該使用大頂堆來維護最小堆,而不能直接創建一個小頂堆並設置一個大小,企圖讓小頂堆中的元素都是最小元素。
維護一個大小爲 K 的最小堆過程如下:在添加一個元素之後,如果大頂堆的大小大於 K,那麼需要將大頂堆的堆頂元素去除。
import java.util.*;
public class Main {
public static void main(String[] args) {
int[] nums = {4,5,1,6,2,7,3,8};
ArrayList list = getLeastNumbers_Solution(nums,4);
printList(list);
}
public static ArrayList<Integer> getLeastNumbers_Solution(int[] nums,int k) {
ArrayList<Integer> list = new ArrayList();
if (k > nums.length || k <= 0) {
return list;
}
PriorityQueue<Integer> maxHeap = new PriorityQueue<>((o1, o2) -> o2 - o1);
for (int num : nums) {
maxHeap.add(num);
if (maxHeap.size() > k) {
maxHeap.poll();
}
}
return new ArrayList<>(maxHeap);
}
public static void printList(ArrayList list) {
for (int i = 0;i < list.size();i++) {
System.out.print(list.get(i) + " ");
}
}
}