2019hdu多校六 Ridiculous Netizens(点分治)

大概题意是: 给你一颗无根树,每一个结点有点权, 有多少颗子树的结点乘积不超过m?子树的定义是树上的连通块。

首先我们考虑另一个问题,假设给你一颗有根树,所以包含根的子树有多少种满足乘积不超过m?

考虑树形dp的做法,定义dp[i][j]是在i被选取后 这颗子树中乘积为j的子树方案数,这样每次将两颗子树合并的复杂度是m*m的。但实际上子树大小限制了状态数不会那么多,所以每次计算一个点的贡献的复杂度是 o(m)的。

这里题解有一个很巧妙的求法,因为如果一个点被选取,那么他的父亲一定被选取。如果一个点不被选取,那么它子树中的所有点也不会被选取。所以当我们从fa->u时,dp[fa]已经包含了u不被选取的情况(因为u还没有被搜索,没有任何节点统计了贡献)。那么我们只需要统计u被选取的情况,因为u被选取了,那么fa一定也被选取了,所以u可以继承fa的信息,然后继续递归即可。

但是这样dfs一次的复杂度是n*m的,还是有点吃不消。注意题目中要求的是乘积。我们可以考虑剩余的子树大小,比如当前有两颗乘积为m,m-1的子树,那么他们都只能在添加大小为1的子树,所以状态可以合并。其实就是一个整除分块。所以dp[i][j]表示i点在选取后还能添加大小为j的子树的方案数,这样状态数只有^{\sqrt{m}}。复杂度是o(n*^{\sqrt{m}}).

包含根的子树统计后,只要太统计一遍不包含根情况,只要把根标记一下,递归根的子树,这里就可以点分治了。

#include <bits/stdc++.h>

using namespace std;

#define N 2025
#define ll long long
#define mod 1000000007
#define go(i,a,b) for(int i=(a);i<=(b);i++)
#define dep(i,a,b) for(int i=(a);i>=(b);i--)
#define pb push_back
#define inf 0x3f3f3f3f
#define ld long double
#define pii pair<int,int>
#define vi vector<int>
#define add(a,b) (a+=(b)%mod)%=mod
#define lowb(x,c,len) lower_bound(c+1,c+len+1,x)-c
#define uppb(x,c,len) upper_bound(c+1,c+len+1,x)-c
#define ls i*2+1
#define rs i*2+2
#define mid (l+r)/2
#define lson l,mid,ls
#define rson mid+1,r,rs
int n,m,sz,cnt,tot,root,ans,las;
int h[N],sum[N],mson[N],vis[N],a[N],w[N],f[N*N],dp[N][N];
struct no{
    int to,n;
};no eg[N*2];
void link(int u,int to){
    eg[++tot]={to,h[u]};h[u]=tot;
    eg[++tot]={u,h[to]};h[to]=tot;
}
void getroot(int u,int fa){
    sum[u]=1;mson[u]=0;
    for(int i=h[u];i;i=eg[i].n){
        int to=eg[i].to;
        if(to==fa||vis[to])continue;
        getroot(to,u);
        sum[u]+=sum[to];
        mson[u]=max(mson[u],sum[to]);
    }
    mson[u]=max(mson[u],sz-sum[u]);
    if(mson[u]<mson[root])root=u;
}
void dfs(int u,int fa){
    go(i,1,cnt)dp[u][i]=0;
    go(i,1,cnt)if(w[i]>=a[u])add(dp[u][f[w[i]/a[u]]],dp[fa][i]);
    for(int i=h[u];i;i=eg[i].n){
        int to=eg[i].to;
        if(to==fa||vis[to])continue;
        dfs(to,u);
        go(i,1,cnt)add(dp[u][i],dp[to][i]);
    }
}
void divide(int u){
    //cout<<u<<endl;
     dp[0][cnt]=1;dfs(u,0);vis[u]=1;
    go(i,1,cnt)add(ans,dp[u][i]);
    for(int i=h[u];i;i=eg[i].n){
        int to=eg[i].to;
        if(vis[to])continue;
        mson[root=0]=sz=sum[to];
        getroot(to,0);divide(root);
    }
}
void solve(){
    mson[root=0]=sz=n;
    getroot(1,-1);divide(root);
    printf("%d\n",ans);
    ans=tot=cnt=las=0;
    go(i,1,n)h[i]=vis[i]=0;
    memset(dp[0],0,sizeof dp[0]);
}
int main()
{
    int T,u,to;cin>>T;while(T--){
        scanf("%d%d",&n,&m);
        dep(i,m,1){
            int x=m/i;
            w[f[x]=(x!=las?++cnt:cnt)]=x;
            las=x;
        }
        go(i,1,n)scanf("%d",&a[i]);
        go(i,2,n)scanf("%d%d",&u,&to),link(u,to);
        solve();
    }
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章