poj 2442(堆的stl實現)

Sequence
Time Limit: 6000MS   Memory Limit: 65536K
Total Submissions: 6281   Accepted: 1962

Description

Given m sequences, each contains n non-negative integer. Now we may select one number from each sequence to form a sequence with m integers. It's clear that we may get n ^ m this kind of sequences. Then we can calculate the sum of numbers in each sequence, and get n ^ m values. What we need is the smallest n sums. Could you help us?

Input

The first line is an integer T, which shows the number of test cases, and then T test cases follow. The first line of each case contains two integers m, n (0 < m <= 100, 0 < n <= 2000). The following m lines indicate the m sequence respectively. No integer in the sequence is greater than 10000.

Output

For each test case, print a line with the smallest n sums in increasing order, which is separated by a space.

Sample Input

1
2 3
1 2 3
2 2 3

Sample Output

3 3 4

Source

POJ Monthly,Guang Lin


首先要深入理解題意:

輸入是m行*n列的矩陣。需要從每一行選一個數字,這樣會構成m個數字的序列。

對這個序列求和,問最小的n個和是什麼。


首先,要深刻理解,我們需要維護多少個數字。因爲題目中只要求前n個最小和,我們只需要每步維護n個最小的數字即可。

比如第一行和第二行各挑一個數字求和,可以形成n*n個數字,排個序後,我們只需要保存前n個數字即可,根本不用管其餘的數。第三行和上一步形成的這n個數字分別求和,又形成了n*n個數字,排個序後,再只保存前n個數字即可。因此可以逐步迭代計算。


如果按照上面的方法,寫一個樸素的暴力算法,則空間複雜度爲n*n,時間複雜度爲(n*n + n*log(n)) * m 。可是TLE了。。

/*Source Code
Problem: 2442		User: 775700879
Memory: N/A		Time: N/A
Language: G++		Result: Time Limit Exceeded

    Source Code*/

    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #define oo 1000000
    #define MAXV 500 * 2
    using namespace std;
    vector result;
    int temp[2000*2000+10];
    int last[2010];
    int line[2010];
    int main() {
        int t;
        scanf("%d", &t);
        while (t--) {
            int n, m;
            scanf("%d%d", &m, &n);
            int i, j, k;
            for (i = 0; i < n; i++) {
                scanf("%d", last+i);
            }
            for (i = 1; i < m; i++) {
                for (j = 0; j < n; j++) {
                    scanf("%d", line+j);
                }
                int count = 0;
                for (j = 0; j < n; j++) {
                    for (k = 0; k < n; k++) {
                        temp[count++] = last[j] + line[k];
                    }
                }
                sort(temp, temp+count);
                for (j = 0; j < n; j++) {
                    last[j] = temp[j];
                }
            }
            for (i = 0; i < n; i++) {
                printf("%d ", last[i]);
            }
            printf("\n");
        }
        return 0;
    }


需要優化。想到用堆。因爲事實上我們每步都只需要維護最小的n個數即可。那我們就維護一個大跟堆(注意這裏是大跟堆),這個堆含有固定的n個元素(意味着初始化的時候,堆中有n個元素。以後只要插入一個元素,就要刪除一個元素來保證堆中元素數目恆定爲n)。
這個思路有了以後,再細想一下步驟:每一步先把thisline[0] 和 [lastheap[0], lastheap[n-1]] 分別相加,形成n個數字,建堆thisheap。再用[thisline[1]-thisline[n-1]]與[lastheap[0], lastheap[n-1]]分別求和,依次與thisheap的堆頂(最大值)比較,如果發現比最大值大,則捨棄。發現比最大值小,則刪除堆中最大值,將新值push到堆中。

堆的幾個重要的stl操作:
stl中的堆都是 大根堆
make_heap(v.begin(), v.end()); 將n個元素的數組建堆
pop_heap(v.begin(), v.end()); 將n個元素的堆的堆頂pop掉。此時數組中的0-n-2是一個有n-1個元素的新堆。被pop掉的原堆頂元素被放到了n-1的位置。因此往往在這一步之後還需要執行 v.pop_back();
push_heap(v.begin(), v.end()); 這一步首先要確保數組中的0-n-2是一個有n-1個元素的堆,而且新元素已經被放置在了v[n-1]位置。調用這個操作會將v[n-1] shiftup到合適的位置。因此在這一步之前要先執行v.push_front(x);
這麼做的話,ac需要2600ms+, 如果加一個小小的優化:將thisline先排序,當發現求和大於堆頂元素時,後面的就都不用比較了。這麼優化可以400+ms AC。
代碼如下:


/*Source Code
Problem: 2442		User: 775700879
Memory: 724K		Time: 407MS
Language: G++		Result: Accepted

    Source Code*/

    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #define oo 1000000
    #define MAXV 500 * 2
    using namespace std;
    vector result;
    int temp[2000*2000+10];
    int last[2010];
    int cur[2010];
    int line[2010];
    int main() {
        int t;
        scanf("%d", &t);
        while (t--) {
            int n, m;
            scanf("%d%d", &m, &n);
            int i, j, k;
            for (i = 0; i < n; i++) {
                scanf("%d", last+i);
            }
            make_heap(last, last+n);
            for (i = 1; i < m; i++) {
                for (j = 0; j < n; j++) {
                    scanf("%d", line+j);
                }
                sort(line, line+n);
                for (j = 0; j < n; j++) {
                    cur[j] = last[j] + line[0];
                } //保證先向堆中插入 n 個元素。
                make_heap(cur, cur+n);//建堆
                for (j = 0; j < n; j++) {
                    for (k = 1; k < n; k++) {
                        if (last[j] + line[k] > cur[0]) 
                            break;
                        pop_heap(cur, cur+n); //將最大的元素pop掉 shiftdown
                        cur[n-1] = last[j] + line[k];
                        push_heap(cur, cur+n);//將新的元素push到堆裏 shiftup
                        //在這個循環中,堆元素始終沒有發生變化,一直是n
                    }
                }
                for (j = 0; j < n; j++) {
                    last[j] = cur[j];
                }
            }
            sort(last, last + n);
            for (i = 0; i < n; i++) {
                printf("%d ", last[i]);
            }
            printf("\n");
        }
        return 0;
    }


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章