HDOJ 5021 Revenge of kNN II

5021 Revenge of kNN II

Time Limit: 8000/5000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)
Total Submission(s): 196    Accepted Submission(s): 56


Problem Description
In pattern recognition, the k-Nearest Neighbors algorithm (or k-NN for short) is a non-parametric method used for classification and regression. In both cases, the input consists of the k closest training examples in the feature space.
In k-NN regression, the output is the property value for the object. This value is the average of the values of its k nearest neighbors.
---Wikipedia

Today, kNN takes revenge on you, again. You have to handle a kNN case in a one-dimensional coordinate system. There are N points, each with a position Xi and a value Vi. Then there are M kNN queries; each query names a point with index i, and you must recalculate its value as the average of the values of its k nearest neighbors. Note that you have to replace the value of the i-th point with the newly calculated value. If there is a tie while choosing the k nearest neighbors, choose the one with the minimal index first.
(Have you ever tried the problem “Revenge of kNN”? They are twin problems!)
 

Input
The first line contains a single integer T, indicating the number of test cases. 

Each test case begins with two integers N and M. Then N lines follow, each containing two integers Xi and Vi. Then M lines follow, each containing the queried index Qi and an integer Ki, where Ki is the number of nearest neighbors to use.

[Technical Specification]
1. 1 <= T <= 5
2. 2 <= N <= 100 000
3. 1 <= M <= 100 000
4. 1 <= Vi <= 1 000
5. 1 <= Xi <= 1 000 000 000, and no two Xi are identical.
6. 1 <= Qi <= N
7. 1 <= Ki <= N - 1
 

Output
For each test case, output the sum of the new values computed by all queries, rounded to three decimal places.
 

Sample Input
1
5 3
1 2
2 3
3 6
4 8
5 8
2 2
3 2
4 2
 

Sample Output
17.000
Hint
For the first query, the 2-NN of point 2 are points 1 and 3, so its new value is (2 + 6) / 2 = 4. For the second query, the 2-NN of point 3 are points 2 and 4; the value of point 2 was changed to 4 by the previous query, so the new value is (4 + 8) / 2 = 6. For the third query, the 2-NN of point 4 are points 3 and 5, so the new value is (6 + 8) / 2 = 7. The answer is therefore 4 + 6 + 7 = 17. The input is huge; a fast I/O method is recommended.
 

官方思路:
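考慮如何快速求出距離最近的 k 個點的權值之和。距離具有明顯的單調性(距離越大,範圍內的點越多),因此可以二分距離 d(範圍是 [0, 最大座標 − 最小座標],而不是座標本身),在 O(log(MAXX)) 次嘗試內求出這 k 個點的座標範圍。求出範圍之後,剩下的問題就是區間求和與單點更新,用樹狀數組即可解決。
二分時要注意:範圍內的點數可能恰好是 K,也可能是 K+1(同一距離上左右兩邊各有一個點同時進入範圍)。由於點數一次可能跳 2,如果捨棄 K+1 的情況而繼續縮小距離,可能直接得到 K−1,永遠找不到恰好為 K 的距離。因此要單獨處理 K+1 的情況:去掉較遠的那個端點;若兩端距離相同,去掉編號較大的那個。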


代碼如下:(二分搜索有些難寫,汗!!)
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;

// Capacity of the 1-based arrays below (N <= 100 000, plus slack).
const int MAXN = 100000 + 5;

// One input point: its original 1-based input index, coordinate and value.
struct node {
    int id;
    int x;
    int v;
};

node a[MAXN];  // the points, indexed from 1

int n, m;          // number of points / number of queries of the current test case

// On glibc, <cstring> (via <strings.h>) may declare the POSIX/BSD function
// `index`, so a global array of that name fails to compile ("redeclared as
// different kind of symbol"). The macro renames the array while leaving every
// later use of `index` in this program unchanged.
#define index pos_of_id
int index[MAXN];   // index[original id] = position of that point after sorting by x

double C[MAXN];    // Fenwick tree (1-based) over the point values in sorted order
int L, R;          // window of sorted positions produced by findLR()
int q, k;          // current query: point (converted to a sorted position) and k

bool cmp(node a, node b){
    return a.x<b.x;
}

// Lowest set bit of x (two's-complement trick that drives the Fenwick tree).
int lowbit(int x){
    int negated = -x;
    return negated & x;
}

// Fenwick-tree point update: adds v at 1-based position i by walking upward
// through every node whose range covers i, stopping past n.
void add(int i, double v){
    for(int pos = i; pos <= n; pos += lowbit(pos))
        C[pos] += v;
}

// Fenwick-tree prefix sum: total of the values at positions 1..i
// (returns 0 for i <= 0).
double Sum(int i){
    double total = 0;
    for(int pos = i; pos > 0; pos -= lowbit(pos))
        total += C[pos];
    return total;
}

// Returns the smallest sorted position p in [1, n] with a[p].x >= x,
// or n + 1 when no such position exists.
// Previously the result variable was left uninitialized in that case (UB),
// and the local `m` shadowed the global query count.
int findL(int x){
    int lo = 1, hi = n;
    int res = n + 1;
    while(lo <= hi){
        int mid = (lo + hi) >> 1;
        if(a[mid].x >= x){
            res = mid;       // candidate; keep looking further left
            hi = mid - 1;
        }else{
            lo = mid + 1;
        }
    }
    return res;
}

// Returns the largest sorted position p in [1, n] with a[p].x <= x,
// or 0 when no such position exists.
// Previously the result variable was left uninitialized in that case (UB),
// and the local `m` shadowed the global query count.
int findR(int x){
    int lo = 1, hi = n;
    int res = 0;
    while(lo <= hi){
        int mid = (lo + hi) >> 1;
        if(a[mid].x <= x){
            res = mid;       // candidate; keep looking further right
            lo = mid + 1;
        }else{
            hi = mid - 1;
        }
    }
    return res;
}

/*
 * Finds the k nearest neighbours of the point at sorted position q and stores
 * the covering window of sorted positions in the globals L and R (the window
 * still contains q itself).
 *
 * Binary search on the distance d: [L, R] are the points with
 * |a[p].x - a[q].x| <= d, so R - L counts the neighbours (q excluded).
 * Coordinates are distinct integers, so increasing d by one admits at most one
 * new point per side; hence some d gives R - L == k or R - L == k + 1.
 * With k + 1 neighbours, the farther endpoint is dropped; on a distance tie,
 * the endpoint with the larger original id is dropped (ties prefer the
 * smaller index).
 */
void findLR(){
    // The search is over distances, not coordinates: the smallest useful
    // distance is 0 and the largest is the full span of the coordinates.
    // (Starting at a[1].x made any query needing d < a[1].x fall through to
    // the whole array.)
    int lo = 0;
    int hi = a[n].x - a[1].x;

    while(lo <= hi){
        int d = lo + (hi - lo) / 2;
        L = findL(a[q].x - d);
        R = findR(a[q].x + d);
        int cnt = R - L;                 // neighbours inside the window

        if(cnt < k){
            lo = d + 1;
        }else if(cnt > k + 1){
            hi = d - 1;
        }else{
            if(cnt == k + 1){
                int dl = a[q].x - a[L].x;
                int dr = a[R].x - a[q].x;
                if(dl < dr || (dl == dr && a[L].id < a[R].id))
                    R--;                 // right end is farther, or ties with larger id
                else
                    L++;
            }
            return;
        }
    }
}

int main(){
    int T;
    double ans;
    scanf("%d", &T);
    while(T--){
        scanf("%d %d", &n, &m);
        for(int i=1;i<=n;i++){
            scanf("%d %d", &a[i].x, &a[i].v);
            a[i].id = i;
        }

        sort(a+1, a+n+1, cmp);
        memset(C, 0, sizeof(C));
        for(int i=1;i<=n;i++){
            index[a[i].id] = i;
            add(i, a[i].v);
        }

        ans = 0.0;
        while(m--){
            scanf("%d %d", &q, &k);
            q = index[q];
            findLR();
            double s = Sum(R) - Sum(L-1);
            double t = Sum(q) - Sum(q-1);
            add(q, (s-t)/k-t);
            ans += (s-t)/k;
        }
        printf("%.3f\n", ans);
    }
    return 0;
}

另一種二分搜索代碼:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;

const int MAXN = 100000 + 5;   // room for 1-based indexing of up to 100 000 points

/* A point as read from the input: original index, coordinate, value. */
struct node {
    int id;
    int x;
    int v;
};

node a[MAXN];   // point storage, 1-based

int n, m;          // number of points / number of queries of the current test case

// On glibc, <cstring> (via <strings.h>) may declare the POSIX/BSD function
// `index`, so a global array of that name fails to compile ("redeclared as
// different kind of symbol"). The macro renames the array while leaving every
// later use of `index` in this program unchanged.
#define index pos_of_id
int index[MAXN];   // index[original id] = position of that point after sorting by x

double C[MAXN];    // Fenwick tree (1-based) over the point values in sorted order
int L, R;          // window of sorted positions produced by findLR()
int q, k;          // current query: point (converted to a sorted position) and k

bool cmp(node a, node b){
    return a.x<b.x;
}

/* Isolates the lowest set bit of x; ~x + 1 is the two's-complement negation. */
int lowbit(int x){
    return x & (~x + 1);
}

/* Fenwick point update: add v at position i and propagate up to n. */
void add(int i, double v){
    for(int node_pos = i; node_pos <= n; node_pos += lowbit(node_pos)){
        C[node_pos] += v;
    }
}

/* Fenwick prefix query: sum of positions 1..i (0 when i <= 0). */
double Sum(int i){
    double acc = 0;
    for(int node_pos = i; node_pos > 0; node_pos -= lowbit(node_pos)){
        acc += C[node_pos];
    }
    return acc;
}

/*
 * First sorted position p in [1, n] with a[p].x >= x, or n + 1 if none.
 * The half-open search converges on a single candidate; it is verified at the
 * end because, when every coordinate is < x, the loop would otherwise return n
 * (a position that does not satisfy the condition).
 */
int findL(int x){
    int lo = 1, hi = n;
    while(lo < hi){
        int mid = (lo + hi) >> 1;
        if(a[mid].x >= x) hi = mid;
        else lo = mid + 1;
    }
    return a[lo].x >= x ? lo : n + 1;
}

/*
 * Last sorted position p in [1, n] with a[p].x <= x, or 0 if none.
 * Upper-middle split so the range always shrinks; the final candidate is
 * verified because, when every coordinate is > x, the loop would otherwise
 * return 1 (a position that does not satisfy the condition).
 */
int findR(int x){
    int lo = 1, hi = n;
    while(lo < hi){
        int mid = (lo + hi + 1) >> 1;
        if(a[mid].x <= x) lo = mid;
        else hi = mid - 1;
    }
    return a[hi].x <= x ? hi : 0;
}

/*
 * Finds the k nearest neighbours of the point at sorted position q and stores
 * the covering window of sorted positions in the globals L and R (the window
 * still contains q itself).
 *
 * Binary search on the distance d: [L, R] are the points with
 * |a[p].x - a[q].x| <= d, so R - L counts the neighbours (q excluded).
 * Coordinates are distinct integers, so increasing d by one admits at most one
 * new point per side; hence some d gives R - L == k or R - L == k + 1.
 * With k + 1 neighbours, the farther endpoint is dropped; on a distance tie,
 * the endpoint with the larger original id is dropped (ties prefer the
 * smaller index).
 */
void findLR(){
    // Search over distances, not coordinates: from 0 up to the full span.
    // (Starting at a[1].x made any query needing d < a[1].x fall through to
    // the whole array.)
    int lo = 0;
    int hi = a[n].x - a[1].x;

    while(lo <= hi){
        int d = lo + (hi - lo) / 2;
        L = findL(a[q].x - d);
        R = findR(a[q].x + d);
        int cnt = R - L;                 // neighbours inside the window

        if(cnt < k){
            lo = d + 1;
        }else if(cnt > k + 1){
            hi = d - 1;
        }else{
            if(cnt == k + 1){
                int dl = a[q].x - a[L].x;
                int dr = a[R].x - a[q].x;
                if(dl < dr || (dl == dr && a[L].id < a[R].id))
                    R--;                 // right end is farther, or ties with larger id
                else
                    L++;
            }
            return;
        }
    }
}

int main(){
    int T;
    double ans;
    scanf("%d", &T);
    while(T--){
        scanf("%d %d", &n, &m);
        for(int i=1;i<=n;i++){
            scanf("%d %d", &a[i].x, &a[i].v);
            a[i].id = i;
        }

        sort(a+1, a+n+1, cmp);
        memset(C, 0, sizeof(C));
        for(int i=1;i<=n;i++){
            index[a[i].id] = i;
            add(i, a[i].v);
        }

        ans = 0.0;
        while(m--){
            scanf("%d %d", &q, &k);
            q = index[q];
            findLR();
            double s = Sum(R) - Sum(L-1);
            double t = Sum(q) - Sum(q-1);
            add(q, (s-t)/k-t);
            ans += (s-t)/k;
        }
        printf("%.3f\n", ans);
    }
    return 0;
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章