5021 Revenge of kNN II
Time Limit: 8000/5000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)
Total Submission(s): 196    Accepted Submission(s): 56
In k-NN regression, the output is the property value for the object. This value is the average of the values of its k nearest neighbors.
---Wikipedia
Today, kNN takes revenge on you, again. You have to handle a kNN case in a one-dimensional coordinate system. There are N points, each with a position Xi and a value Vi. Then there are M kNN queries: for the point with index i, recalculate its value by averaging the values of its k nearest neighbors. Note that you have to replace the value of the i-th point with the newly calculated value. If there is a tie while choosing the k nearest neighbors, choose the one with the minimal index first.
(Have you ever tried the problem “Revenge of kNN”? They are twin problems!)
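The first line of input contains the number of test cases T. Each test case begins with two integers N and M. Then N lines follow, each containing two integers Xi and Vi. Then M lines follow, each containing the queried index Qi and Ki, where Ki is the number of nearest neighbors to average. For each test case, output the sum of the new values computed by all M queries, printed with three digits after the decimal point.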
[Technical Specification]
1. 1 <= T <= 5
2. 2 <= N <= 100 000
3. 1 <= M <= 100 000
4. 1 <= Vi <= 1 000
5. 1 <= Xi <= 1 000 000 000, and no two Xi are identical.
6. 1 <= Qi <= N
7. 1 <= Ki <= N - 1
在二分距離時要注意:鄰居數為 K 和 K+1 都可能是可接受的結果。距離每增加 1,左右兩側最多各多納入一個點,所以鄰居數可能從 K-1 直接跳到 K+1;如果把算出 K+1 的情況丟掉並縮小距離,得到的可能是 K-1,並不連續。因此 K+1 的情況要特別處理:捨棄較遠的那個端點,距離相同時捨棄編號較大的那個。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// Capacity for the 1-based point arrays (N <= 100000, plus slack).
const int MAXN = 100005;

// One input point.
struct node {
    int id;   // 1-based position in the input order (used for tie-breaking)
    int x;    // coordinate X_i
    int v;    // initial value V_i
};

node a[MAXN];      // points; main sorts a[1..n] by x before answering queries
int n;             // number of points in the current test case
int m;             // number of queries in the current test case
// NOTE(review): on glibc toolchains <cstring> may pull in <strings.h>, which
// declares a legacy function named `index`; g++ then rejects this global.
// Renaming it requires updating main at the same time.
int index[MAXN];   // index[original id] -> position in the sorted a[]
double C[MAXN];    // Fenwick tree over the current point values
int L;             // left end of the window chosen by findLR()
int R;             // right end of the window chosen by findLR()
int q;             // sorted position of the point being queried
int k;             // number of neighbours requested by the current query
// Sort predicate for std::sort: orders points by ascending coordinate.
bool cmp(node a, node b){
    const bool bIsFurtherRight = b.x > a.x;
    return bIsFurtherRight;
}
// Isolates the lowest set bit of x; this is the Fenwick-tree step size.
// (~x + 1 is the two's-complement negation of x.)
int lowbit(int x){
    int negated = ~x + 1;
    return x & negated;
}
// Fenwick point update: adds v at 1-based position i of the tree C[1..n].
void add(int i, double v){
    for(; i <= n; i += lowbit(i))
        C[i] += v;
}
// Fenwick prefix query: returns the total stored at positions 1..i of C.
double Sum(int i){
    double acc = 0;
    for(int pos = i; pos > 0; pos -= lowbit(pos))
        acc += C[pos];
    return acc;
}
// Returns the first position i in [1, n] with a[i].x >= x (lower_bound
// semantics), or n + 1 when every coordinate is smaller than x.
// Assumes a[1..n] is sorted by x (main sorts it before any query).
int findL(int x){
    int l = 1, r = n;
    int res1 = n + 1;            // n + 1 = "no point qualifies"; never left indeterminate
    while(l <= r){
        int mid = (l + r) >> 1;
        if(a[mid].x >= x){
            res1 = mid;          // candidate; keep looking further left
            r = mid - 1;
        }
        else l = mid + 1;
    }
    return res1;
}
// Returns the last position i in [1, n] with a[i].x <= x, or 0 when every
// coordinate is larger than x.
// Assumes a[1..n] is sorted by x (main sorts it before any query).
int findR(int x){
    int l = 1, r = n;
    int res2 = 0;                // 0 = "no point qualifies"; never left indeterminate
    while(l <= r){
        int mid = (l + r) >> 1;
        if(a[mid].x <= x){
            res2 = mid;          // candidate; keep looking further right
            l = mid + 1;
        }
        else r = mid - 1;
    }
    return res2;
}
// Chooses the window [L, R] of sorted positions that holds point q plus its
// k nearest neighbours; distance ties are broken in favour of the smaller
// original id. Reads globals a, n, q, k; writes globals L and R.
//
// Binary-searches a distance d. With L/R the bounds of all points within d of
// a[q], cnt(d) = R - L is the number of neighbours (q itself is inside the
// window). Coordinates are distinct integers, so raising d by 1 adds at most
// one point per side. Hence some d gives cnt == k or cnt == k + 1. In the
// k + 1 case exactly one endpoint must be dropped: the farther one, or on a
// distance tie the one with the larger original id.
// Assumes 1 <= k <= n - 1, as the problem guarantees.
void findLR(){
    // The search must start at distance 0: the needed distance can be smaller
    // than a[1].x, and starting there can skip every valid window.
    int lo = 0, hi = a[n].x - a[1].x;
    while(lo <= hi){
        int mid = lo + (hi - lo) / 2;
        L = findL(a[q].x - mid);
        R = findR(a[q].x + mid);
        int cnt = R - L;
        if(cnt < k) lo = mid + 1;
        else if(cnt > k + 1) hi = mid - 1;
        else {
            if(cnt == k + 1){
                int dl = a[q].x - a[L].x;
                int dr = a[R].x - a[q].x;
                // Keep the nearer endpoint; on a tie keep the smaller id.
                if(dl < dr || (dl == dr && a[L].id < a[R].id)) R--;
                else L++;
            }
            return;
        }
    }
}
int main(){
int T;
double ans;
scanf("%d", &T);
while(T--){
scanf("%d %d", &n, &m);
for(int i=1;i<=n;i++){
scanf("%d %d", &a[i].x, &a[i].v);
a[i].id = i;
}
sort(a+1, a+n+1, cmp);
memset(C, 0, sizeof(C));
for(int i=1;i<=n;i++){
index[a[i].id] = i;
add(i, a[i].v);
}
ans = 0.0;
while(m--){
scanf("%d %d", &q, &k);
q = index[q];
findLR();
double s = Sum(R) - Sum(L-1);
double t = Sum(q) - Sum(q-1);
add(q, (s-t)/k-t);
ans += (s-t)/k;
}
printf("%.3f\n", ans);
}
return 0;
}
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// Capacity for the 1-based point arrays (N <= 100000, plus slack).
const int MAXN = 100005;

// One input point: original input order, coordinate and value.
struct node {
    int id;   // 1-based position in the input order (used for tie-breaking)
    int x;    // coordinate X_i
    int v;    // initial value V_i
};

node a[MAXN];      // points; main sorts a[1..n] by x before answering queries
int n;             // number of points in the current test case
int m;             // number of queries in the current test case
// NOTE(review): on glibc toolchains <cstring> may pull in <strings.h>, which
// declares a legacy function named `index`; g++ then rejects this global.
// Renaming it requires updating main at the same time.
int index[MAXN];   // index[original id] -> position in the sorted a[]
double C[MAXN];    // Fenwick tree over the current point values
int L;             // left end of the window chosen by findLR()
int R;             // right end of the window chosen by findLR()
int q;             // sorted position of the point being queried
int k;             // number of neighbours requested by the current query
// Sort predicate for std::sort: true when a lies strictly left of b.
bool cmp(node a, node b){
    if(a.x < b.x)
        return true;
    return false;
}
// Isolates the lowest set bit of x; this is the Fenwick-tree step size.
// x & (x - 1) clears that bit, so XOR with x leaves only that bit.
int lowbit(int x){
    int withoutLowest = x & (x - 1);
    return x ^ withoutLowest;
}
// Fenwick point update: adds v at 1-based position i of the tree C[1..n].
void add(int i, double v){
    int pos = i;
    while(pos <= n){
        C[pos] += v;
        pos += lowbit(pos);
    }
}
// Fenwick prefix query: returns the total stored at positions 1..i of C.
double Sum(int i){
    double total = 0;
    for(; i > 0; i -= lowbit(i))
        total += C[i];
    return total;
}
// Returns the first position i in [1, n] with a[i].x >= x (lower_bound
// semantics), or n + 1 when every coordinate is smaller than x.
// Assumes a[1..n] is sorted by x (main sorts it before any query).
int findL(int x){
    // Search over [1, n + 1]. Position n + 1 means "past the end", so a
    // missing answer is no longer reported as the real position n.
    int l = 1, r = n + 1;
    while(l < r){
        int mid = (l + r) >> 1;   // mid < r <= n + 1, so a[mid] is a real point
        if(a[mid].x >= x) r = mid;
        else l = mid + 1;
    }
    return l;
}
// Returns the last position i in [1, n] with a[i].x <= x, or 0 when every
// coordinate is larger than x.
// Assumes a[1..n] is sorted by x (main sorts it before any query).
int findR(int x){
    // Search over [0, n]. Position 0 means "before the beginning", so a
    // missing answer is no longer reported as the real position 1.
    int l = 0, r = n;
    while(l < r){
        int mid = (l + r + 1) >> 1;   // mid > l >= 0, so a[mid] is a real point
        if(a[mid].x <= x) l = mid;
        else r = mid - 1;
    }
    return l;
}
// Chooses the window [L, R] of sorted positions that holds point q plus its
// k nearest neighbours; distance ties are broken in favour of the smaller
// original id. Reads globals a, n, q, k; writes globals L and R.
//
// Binary-searches a distance d. With L/R the bounds of all points within d of
// a[q], cnt(d) = R - L is the number of neighbours (q itself is inside the
// window). Coordinates are distinct integers, so raising d by 1 adds at most
// one point per side. Hence some d gives cnt == k or cnt == k + 1. In the
// k + 1 case exactly one endpoint must be dropped: the farther one, or on a
// distance tie the one with the larger original id.
// Assumes 1 <= k <= n - 1, as the problem guarantees.
void findLR(){
    // The search must start at distance 0: the needed distance can be smaller
    // than a[1].x, and starting there can skip every valid window.
    int lo = 0, hi = a[n].x - a[1].x;
    while(lo <= hi){
        int mid = lo + (hi - lo) / 2;
        L = findL(a[q].x - mid);
        R = findR(a[q].x + mid);
        int cnt = R - L;
        if(cnt < k) lo = mid + 1;
        else if(cnt > k + 1) hi = mid - 1;
        else {
            if(cnt == k + 1){
                int dl = a[q].x - a[L].x;
                int dr = a[R].x - a[q].x;
                // Keep the nearer endpoint; on a tie keep the smaller id.
                if(dl < dr || (dl == dr && a[L].id < a[R].id)) R--;
                else L++;
            }
            return;
        }
    }
}
int main(){
int T;
double ans;
scanf("%d", &T);
while(T--){
scanf("%d %d", &n, &m);
for(int i=1;i<=n;i++){
scanf("%d %d", &a[i].x, &a[i].v);
a[i].id = i;
}
sort(a+1, a+n+1, cmp);
memset(C, 0, sizeof(C));
for(int i=1;i<=n;i++){
index[a[i].id] = i;
add(i, a[i].v);
}
ans = 0.0;
while(m--){
scanf("%d %d", &q, &k);
q = index[q];
findLR();
double s = Sum(R) - Sum(L-1);
double t = Sum(q) - Sum(q-1);
add(q, (s-t)/k-t);
ans += (s-t)/k;
}
printf("%.3f\n", ans);
}
return 0;
}