Little M’s attack plan
題意:給一顆帶有權值的樹,然後q個詢問,每個詢問x和k,回答與點x的距離在k以內的所有點的權值和。
思路:容斥原理,以下是對於每個詢問的回答,dp[i][j]代表點i的子樹中的所有點距離點i在j個單位距離的點的權值和。
一開始開了個dp[MAX_N][110]的數組,數組太大而且會T。
scanf("%d%d",&x,&k);
long long ans=0;
int pre=x;
for(j=k;j>=0;j--){
ans+=dp[pre][j];
if(pre==1)
break;
if(j>=2)
ans-=dp[pre][j-2];
pre=fa[pre];
}
printf("%lld\n",ans);
然後因爲q詢問一共5000個,可以離線看哪些dp[i][j]可以用到,然後跑dfs,樹狀數組存儲遍歷到的所有點某個深度的權值和,跑到一個點,如果這個點有詢問,就進行查詢當前點的深度deep[x]到deep[x]+k的所有點的權值和,然後當跑完子樹再回溯到這個點時再查詢一遍,這時候那這個值減去之前的值就是這個點的子樹上的答案,即dp[x][k]。
如果是用vector,注意哪些dp[i][j]已經被放在vector裏面了,防止重複計算,因爲容易出錯。去重!!!
#include<iostream>
#include<cstdio>
#include<map>
#include<vector>
using namespace std;
const int MAX_N=1010000;
map<int,long long>ma[MAX_N];
map<int,bool>mb[MAX_N];
int head[MAX_N],ver[2*MAX_N],Next[2*MAX_N];
int tot;
vector<int>v[MAX_N];
void Add(int x,int y){
ver[++tot]=y;Next[tot]=head[x];head[x]=tot;
}
long long sum[MAX_N],a[MAX_N];
int nn;
int deep[MAX_N];
int fa[MAX_N];
struct skt{
int x,k;
}b[5010];
void add(int p,long long x){
while(p<=nn){
sum[p]+=x;
p+=p&-p;
}
}
long long ask(int p){
long long ans=0;
while(p){
ans+=sum[p];
p-=p&-p;
}
return ans;
}
void dfs1(int x){
for(int i=head[x];i;i=Next[i]){
int y=ver[i];
if(y==fa[x])
continue;
fa[y]=x;
deep[y]=deep[x]+1;
nn=max(nn,deep[y]);
dfs1(y);
}
}
void dfs2(int x){
int i;
for(i=0;i<v[x].size();i++){
int k=v[x][i];
ma[x][k]=ask(deep[x]+k)-ask(deep[x]-1);
}
add(deep[x],a[x]);
for(i=head[x];i;i=Next[i]){
int y=ver[i];
if(y==fa[x])
continue;
dfs2(y);
}
for(i=0;i<v[x].size();i++){
int k=v[x][i];
ma[x][k]=ask(deep[x]+k)-ask(deep[x]-1)-ma[x][k];
}
}
int main(void){
int n,i,j,x,y,q;
scanf("%d",&n);
for(i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
for(i=1;i<n;i++){
scanf("%d%d",&x,&y);
Add(x,y);
Add(y,x);
}
deep[1]=1;
dfs1(1);
nn+=110;
scanf("%d",&q);
for(i=0;i<q;i++){
scanf("%d%d",&b[i].x,&b[i].k);
int pre=b[i].x;
for(j=b[i].k;j>=0;j--){
if(!mb[pre][j]){
v[pre].push_back(j);
mb[pre][j]=true;
}
if(pre==1)
break;
if(j>=2){
if(!mb[pre][j-2]){
v[pre].push_back(j-2);
mb[pre][j-2]=true;
}
}
pre=fa[pre];
}
}
dfs2(1);
for(i=0;i<q;i++){
long long ans=0;
int pre=b[i].x;
for(j=b[i].k;j>=0;j--){
ans+=ma[pre][j];
if(pre==1)
break;
if(j>=2)
ans-=ma[pre][j-2];
pre=fa[pre];
}
printf("%lld\n",ans);
}
return 0;
}