元素選擇問題

問題描述:從n個元素中選擇第k小的元素。此問題存在最壞情況下時間複雜度爲O(n)的算法,但本文不作討論本文介紹一種使用“快速排序”算法的思想求解此問題的方法,平均時間複雜度O(nlogn) 。

“快速排序算法”的實質是“遞歸與分治”,以序列中a[l:r]的某個元素a[i]爲基準將序列分成3部分:a[l:i-1],a[i]和a[i+1,r],

使之滿足a[l:i-1]中的元素都小於a[i],同時a[i+1,r]中的元素都大於a[i]。然後遞歸地對a[l:i-1]和a[i+1,r]進行排序。

此算法的下一個特點是在每一次的“分區”過程中,被選擇作爲“基準”的元素就落在在了其最後的位置上。根據這個特點可以設計出求解“元素選擇”問題的算法。

以POJ 2388 爲例,partition算法實現時,選擇的第一個位置上的元素作爲基準(這種選擇方式實現起來還算比較簡單),當然如果想選用其他位置上的元素也可以直接將它與第一個位置上的元素交換。本次實現中就是選取開頭、中間和結尾“三者取中”的方式,選取的基準越能將所有元素平均分配在兩邊,算法的效率越高。代碼如下:


/* 204K    16MS */
#include <stdio.h>
#include <stdlib.h>

int g_output[10000];

void swap(int a[], int p, int q)
{
    int tmp;

    tmp = a[p];
    a[p] = a[q];
    a[q] = tmp;
}
int partition(int arr[], int l, int r)
{
    int pivotKey;
    int low, high, mid;
    
    //採用頭部、中間、尾部“三者取中”的方式,確定樞軸點
    mid = (l + r) / 2;
    if ((arr[mid] - arr[l])*(arr[mid] - arr[r]) <= 0)
    {
        swap(arr, l, mid);
    }
    else if ((arr[r] - arr[l])*(arr[r] - arr[mid]) <= 0)
    {
        swap(arr, l, r);
    }

    low = l;
    high = r;
    pivotKey = arr[low];

    while (low < high) {
        while (high > low && arr[high] >= pivotKey) {
            high--;
        }
        arr[low] = arr[high];
        while (low < high && arr[low] <= pivotKey) {
            low++;
        }
        arr[high] = arr[low];
    }
    arr[low] = pivotKey;

    return low;
}

int select(int arr[], int l, int r, int k)
{
    int i;

    i = partition(arr, l, r);
    if (k == i) {
        return arr[i];
    } else if (k < i) {
        return select(arr, l, i - 1, k);
    } else {
        return select(arr, i + 1, r, k);
    }
}
int main()
{
    int n;
    int i, rel;
    
    scanf("%d", &n);
    for (i = 0; i < n; ++i) {
        scanf("%d", &g_output[i]);
    }
    rel = select(g_output, 0, n - 1, (n-1)/2);
    printf("%d", rel);
}

下面再提供一個使用C++ 的nth_element函數實現的代碼: /* 204K, 0MS*/ 由此可見C++中nth_element的實現還是很高效的。

#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;

int g_output[10000];

int main()
{
    int n;
    int i;

    scanf("%d", &n);
    for (i = 0; i < n; ++i) {
        scanf("%d", &g_output[i]);
    }

    nth_element(g_output, g_output + (n-1)/2, g_output + n);
    printf("%d", g_output[(n-1)/2]);
}


發佈了44 篇原創文章 · 獲贊 7 · 訪問量 8萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章