Forest Game 【期望】【點分治】【FFT】

題目鏈接:https://vjudge.net/problem/Gym-101234D

題目大意:給一顆樹,共N個點,每次隨機選擇一個點,得分加上該點所在樹的大小,然後刪除這個點,斷開與其相連的所有邊,問刪完所有點所獲得的   期望得分乘N! 是多少。

emmm被這個乘N! 給坑了,以爲這是個假的期望題目,因爲全排列共N!種, 期望乘N! 其實就是全排列的和,覺得這是出題人留的解題思路.....然後就去考慮和的問題了,從此走上了不歸路。

 

看了題解後發現,這個N! 一點用都沒有啊啊啊啊啊啊啊。

重新用期望的方式開始思考,我們考慮節點u對節點v的貢獻,u對v有貢獻當且僅當u是u->v這條鏈上第一個刪除的點,概率爲1/len(len是該鏈上點的個數,包含端點)。

然後就變成了求樹上每種長度(長度定義爲路徑上點數)的路徑條數的模板,除了長度爲1的以外,都需要算2次。

點分治的過程中使用fft合併就可以了。

複雜度O (N * logN * logN)

#include <bits/stdc++.h>
#define ll long long
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
#define pb push_back
#define ll long long
using namespace std;
const int M = 1e5+1000;
const int mod = 1e9+7;
const double PI = acos(-1.0);
struct Complex  {
    double x, y;
    Complex(double _x=0.0,double _y = 0.0) {
        x = _x;
        y = _y;
    }
    Complex operator -(const Complex &b) const {
        return Complex(x-b.x, y-b.y);
    }
    Complex operator +(const Complex &b) const {
        return Complex(x+b.x, y+b.y);
    }
    Complex operator *(const Complex &b) const {
        return Complex(x*b.x-y*b.y, x*b.y+y*b.x);
    }
};
void change(Complex y[], int len) {
    int i,j,k;
    for(i = 1, j = len/2; i < len-1; i++) {
        if(i < j) swap(y[i],y[j]);
        k = len/2;
        while(j >= k) {
            j -= k;
            k /= 2;
        }
        if(j < k) j += k;
    }
}
void fft(Complex y[], int len, int on) {
    change(y, len);
    for(int h = 2; h <= len; h <<= 1) {
        Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
        for(int j = 0; j <= len; j+= h) {
            Complex w(1,0);
            for(int k = j; k < j+h/2; k++) {
                Complex u = y[k];
                Complex t = w*y[k+h/2];
                y[k] = u+t;
                y[k+h/2] = u-t;
                w = w*wn;

            }
        }
    }
    if(on == -1)
        for(int i = 0; i < len; i++)
            y[i].x /= len;
}
const int N = 800010; //表達式長度的八倍
Complex x1[N],x2[N];
ll sum[N],Tnum[N],num[N];
void mul(ll a[],ll b[],ll sum[],int n,int m) {
    int len1 = n+1;
    int len2 = m+1;
    int len = 1;
    while(len < len1*2 || len < len2*2) len <<= 1;
    for(int i = 0; i < len1; i++) x1[i] = Complex(a[i],0);
    for(int i = len1; i < len; i++) x1[i] = Complex(0,0);
    for(int i = 0; i < len2; i++) x2[i] = Complex(b[i],0);
    for(int i = len2; i < len; i++) x2[i] = Complex(0,0);
    fft(x1,len,1);
    fft(x2,len,1);
    for(int i = 0; i < len; i++) x1[i] = x1[i]*x2[i];
    fft(x1,len,-1);
    for(int i = 0; i < len1+len2-1; i++) sum[i] += (ll)(x1[i].x+0.5);
}
//--------------------------------------------------


int n;
struct node {
    int v,nxt;
}edge[2*M];
int tot,head[M];
void ae(int u,int v) {
    edge[++tot] = node{v,head[u]};
    head[u] = tot;
}
void init(int n) {
    tot = 0;
    rep(i, 1, n) head[i] = -1;
}
int siz[M],Root,wt[M],Tsiz,len1,len2;
bool vis[M];

void GetRoot(int u,int f) {
    siz[u] = 1;
    wt[u] = 0;
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(v==f||vis[v]) continue;
        GetRoot(v,u);
        siz[u] += siz[v];
        wt[u] = max(wt[u],siz[v]);
    }
    wt[u] = max(wt[u],Tsiz-siz[u]);
    if(wt[Root]>wt[u]) Root = u;
}

void dfs(int u,int f,int dis) {
    num[dis]++;
    len2 = max(len2,dis);
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(v==f||vis[v]) continue;
        dfs(v,u,dis+1);
    }
}

void calc(int u) {
    len1 = 1;  Tnum[1] = 1;
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v]) continue;

        len2 = 1;
        dfs(v,u,1);

        mul(Tnum,num,sum,len1,len2);
        rep(i, 1 ,len2) {
            Tnum[i+1] += num[i];
            num[i] = 0;
        }
        len1 = max(len1,len2+1);
    }
    rep(i, 0, len1) Tnum[i] = 0;
}
void divide(int u) {
    calc(u);
    vis[u] = 1;         //刪掉該點
    for(int i = head[u]; ~i ; i = edge[i].nxt) {
        int v = edge[i].v;
        if(vis[v]) continue;
        Root = 0,Tsiz = siz[v];
        GetRoot(v,0);
        divide(Root);
    }
}
ll pow_mod(ll n,ll m) {
    ll ans = 1;
    while(m) {
        if(m&1) ans = ans*n%mod;
        n = n*n%mod;
        m >>= 1;
    }
    return ans;
}

int main() {
    //freopen("a.txt","r",stdin);
    //ios::sync_with_stdio(0);
    scanf("%d",&n);
    init(n);
    rep(i, 1, n-1) {
        int u,v;
        scanf("%d%d",&u,&v);
        ae(u,v);
        ae(v,u);
    }
    rep(i, 1, n) vis[i] = 0;
    wt[0] = 1e9,Tsiz = n,GetRoot(1,0);
    divide(Root);
    ll ans = 0;
    sum[1] = n;
    rep(i, 2, n) sum[i] = sum[i]*2%mod;
    rep(i, 1, n) ans = (ans +  sum[i] * pow_mod(i,mod-2))%mod;

    rep(i, 1, n) ans = ans*i%mod;
    printf("%lld",ans);
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章