Forest Game 【期望】【点分治】【FFT】

题目链接:https://vjudge.net/problem/Gym-101234D

题目大意:给一颗树,共N个点,每次随机选择一个点,得分加上该点所在树的大小,然后删除这个点,断开与其相连的所有边,问删完所有点所获得的   期望得分乘N! 是多少。

emmm被这个乘N! 给坑了,以为这是个假的期望题目,因为全排列共N!种, 期望乘N! 其实就是全排列的和,觉得这是出题人留的解题思路.....然后就去考虑和的问题了,从此走上了不归路。

 

看了题解后发现,这个N! 一点用都没有啊啊啊啊啊啊啊。

重新用期望的方式开始思考,我们考虑节点u对节点v的贡献,u对v有贡献当且仅当u是u->v这条链上第一个删除的点,概率为1/len(len是该链上点的个数,包含端点)。

然后就变成了求树上每种长度(长度定义为路径上点数)的路径条数的模板,除了长度为1的以外,都需要算2次。

点分治的过程中使用fft合并就可以了。

复杂度O (N * logN * logN)

#include <bits/stdc++.h>
#define ll long long
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
#define pb push_back
#define ll long long
using namespace std;
const int M = 1e5+1000;
const int mod = 1e9+7;
const double PI = acos(-1.0);
struct Complex  {
    double x, y;
    Complex(double _x=0.0,double _y = 0.0) {
        x = _x;
        y = _y;
    }
    Complex operator -(const Complex &b) const {
        return Complex(x-b.x, y-b.y);
    }
    Complex operator +(const Complex &b) const {
        return Complex(x+b.x, y+b.y);
    }
    Complex operator *(const Complex &b) const {
        return Complex(x*b.x-y*b.y, x*b.y+y*b.x);
    }
};
void change(Complex y[], int len) {
    int i,j,k;
    for(i = 1, j = len/2; i < len-1; i++) {
        if(i < j) swap(y[i],y[j]);
        k = len/2;
        while(j >= k) {
            j -= k;
            k /= 2;
        }
        if(j < k) j += k;
    }
}
void fft(Complex y[], int len, int on) {
    change(y, len);
    for(int h = 2; h <= len; h <<= 1) {
        Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
        for(int j = 0; j <= len; j+= h) {
            Complex w(1,0);
            for(int k = j; k < j+h/2; k++) {
                Complex u = y[k];
                Complex t = w*y[k+h/2];
                y[k] = u+t;
                y[k+h/2] = u-t;
                w = w*wn;

            }
        }
    }
    if(on == -1)
        for(int i = 0; i < len; i++)
            y[i].x /= len;
}
const int N = 800010; //表达式长度的八倍
Complex x1[N],x2[N];
ll sum[N],Tnum[N],num[N];
void mul(ll a[],ll b[],ll sum[],int n,int m) {
    int len1 = n+1;
    int len2 = m+1;
    int len = 1;
    while(len < len1*2 || len < len2*2) len <<= 1;
    for(int i = 0; i < len1; i++) x1[i] = Complex(a[i],0);
    for(int i = len1; i < len; i++) x1[i] = Complex(0,0);
    for(int i = 0; i < len2; i++) x2[i] = Complex(b[i],0);
    for(int i = len2; i < len; i++) x2[i] = Complex(0,0);
    fft(x1,len,1);
    fft(x2,len,1);
    for(int i = 0; i < len; i++) x1[i] = x1[i]*x2[i];
    fft(x1,len,-1);
    for(int i = 0; i < len1+len2-1; i++) sum[i] += (ll)(x1[i].x+0.5);
}
//--------------------------------------------------


int n;
struct node {
    int v,nxt;
}edge[2*M];
int tot,head[M];
void ae(int u,int v) {
    edge[++tot] = node{v,head[u]};
    head[u] = tot;
}
void init(int n) {
    tot = 0;
    rep(i, 1, n) head[i] = -1;
}
int siz[M],Root,wt[M],Tsiz,len1,len2;
bool vis[M];

void GetRoot(int u,int f) {
    siz[u] = 1;
    wt[u] = 0;
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(v==f||vis[v]) continue;
        GetRoot(v,u);
        siz[u] += siz[v];
        wt[u] = max(wt[u],siz[v]);
    }
    wt[u] = max(wt[u],Tsiz-siz[u]);
    if(wt[Root]>wt[u]) Root = u;
}

void dfs(int u,int f,int dis) {
    num[dis]++;
    len2 = max(len2,dis);
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(v==f||vis[v]) continue;
        dfs(v,u,dis+1);
    }
}

void calc(int u) {
    len1 = 1;  Tnum[1] = 1;
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v]) continue;

        len2 = 1;
        dfs(v,u,1);

        mul(Tnum,num,sum,len1,len2);
        rep(i, 1 ,len2) {
            Tnum[i+1] += num[i];
            num[i] = 0;
        }
        len1 = max(len1,len2+1);
    }
    rep(i, 0, len1) Tnum[i] = 0;
}
void divide(int u) {
    calc(u);
    vis[u] = 1;         //删掉该点
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v]) continue;
        Root = 0,Tsiz = siz[v];
        GetRoot(v,0);
        divide(Root);
    }
}
ll pow_mod(ll n,ll m) {
    ll ans = 1;
    while(m) {
        if(m&1) ans = ans*n%mod;
        n = n*n%mod;
        m >>= 1;
    }
    return ans;
}

int main() {
    //freopen("a.txt","r",stdin);
    //ios::sync_with_stdio(0);
    scanf("%d",&n);
    init(n);
    rep(i, 1, n-1) {
        int u,v;
        scanf("%d%d",&u,&v);
        ae(u,v);
        ae(v,u);
    }
    rep(i, 1, n) vis[i] = 0;
    wt[0] = 1e9,Tsiz = n,GetRoot(1,0);
    divide(Root);
    ll ans = 0;
    sum[1] = n;
    rep(i, 2, n) sum[i] = sum[i]*2%mod;
    rep(i, 1, n) ans = (ans +  sum[i] * pow_mod(i,mod-2))%mod;

    rep(i, 1, n) ans = ans*i%mod;
    printf("%lld",ans);
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章