Tree (樹上期望dp)

Tree

11.2 和 11.3

11.2

考慮1個隨機過程,第1次走到u號點的時間可以分成兩部分,第1部分是從1號點隨機遊走第1次走到u的父親p的時間,第2部分是從p開始走,第1次走到u的時間,由期望的線性性,第1次走到u的時間期望等於這兩部分期望的和。第1部分是一個子問題,我們考慮怎麼解決第2部分,我們把這個問題變成1棵樹(並且根節點腦袋上也有1條邊),從根節點開始隨機遊走,走出這棵樹期望的時間,我們x[u]表示這個期望,我們對u的子樹中的點也類似地定義x[v],這樣我們可以列出關係式:

這裏寫圖片描述

其中d是的u度數(包括那根天線),這個關係是中的第一個1表示直接向上.,後面那個擴號中的三部分,那個1表示從u走向v, x[v]表示從v回來期望時間, x[u]表示這個時候繼續走,走出去還需要花的時間。因爲是等概率,所以直接乘以1/d這個概率即可。化簡後是:

這裏寫圖片描述

即x[u]等於這棵子樹的所有節點度的和,考慮到除了那根天線之外,所有的邊對度的貢獻爲2,所以:

這裏寫圖片描述

這樣,子問題就有了一個簡單的答案了。我們回到原問題,dp[u]表示第一次走到u的期望時間,p表示u的父親,有:

這裏寫圖片描述

完美解決了這個問題,複雜度O(n),其實答案都是整數,那三位小數也是來騙你的^_^。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#define LL long long
#define N 1000010
using namespace std;

int n, idc=0, idx=0;
int head[N], siz[N], fa[N], seq[N];
LL ans[N];

struct Edge{
    int to, nxt;
}ed[N];

void adde(int u, int v){
    ed[++idc].to = v;
    ed[idc].nxt = head[u];
    head[u] = idc;
}

void dfs(int u){
    seq[++idx] = u;
    siz[u] = 1;
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == fa[u]) continue;
        fa[v] = u;
        dfs( v ); 
        siz[u] += siz[v];
    }
}

void dfs2(int u, int fa){
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == fa) continue;
        ans[v] = ans[u] + 2 * (n - siz[v]) - 1;
        dfs2(v, u);
    }
}

int main(){
    freopen ("tree.in", "r", stdin);
    freopen ("tree.out", "w", stdout);
    scanf("%d", &n); ans[1] = 1;
    for(int i=1; i<n; i++){
        int u, v; scanf("%d%d", &u, &v); 
        adde(u, v); adde(v, u);
    }
    dfs( 1 );
    ans[1]=1;
    dfs2(1,1);
    for(int i=1; i<=n; i++)
        printf("%lld.000\n", ans[i]);
    return 0;
}

11.3

從u走到v一定是從u到fa[u], 在到fa[fa[u]], …, 再從u,v的lca一步一步走到v. 那麼如果能夠算出從u到fa[u]的期望步數(設爲f[u])和從fa[u]到u的期望步數(設爲g[u])就能夠做了.

這裏寫圖片描述

式子就是在枚舉第一步怎麼走來列方程. 化簡發現f[u]和g[u]其實是整數. 先求f,再求g,
然後求lca就可以了.

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define LL long long
#define N 200010
#define mod 1000000007
#define P 17
using namespace std;

int n, q, idc=0, idx=0;
int head[N], siz[N], fa[N], dep[N];
int pw[P+1], acc[N][P+1];
LL ans[N], sum[N];

inline int read(){
    int x = 0, f = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){ if(ch == '-') f = -1; ch = getchar(); }
    while(ch >= '0' && ch <= '9'){ x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

LL up(LL x){
    while (x >= mod) x -= mod;
    while (x < 0) x += mod;
    return x;
}

struct Edge{
    int to, nxt;
}ed[N<<1];

inline void adde(int u, int v){
    ed[++idc].to = v;
    ed[idc].nxt = head[u];
    head[u] = idc;
}

inline void dfs(int u, int f){
    siz[u] = 1;
    acc[u][0] = f; 
    for(int i=1; i<=P; i++) 
        acc[u][i] = acc[acc[u][i-1]][i-1];
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs( v, u ); 
        siz[u] += siz[v];
    }
}

inline void dfs2(int u, int f){
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == f) continue;
        ans[v] = up(ans[u] + 2LL * (n - siz[v]) - 1LL);
        sum[v] = up(sum[u] + 2LL * siz[v] - 1LL);
        dfs2(v, u);
    }
}

inline int lca(int x, int y){
    if(dep[x] < dep[y]) swap(x, y);
    int t = dep[x] - dep[y];
    for(int i=0; pw[i]<=t; i++)
        if(t & pw[i]) x = acc[x][i];
    for(int i=P; i>=0; i--)
        if(acc[x][i] != acc[y][i])
            x = acc[x][i], y = acc[y][i];
    if(x == y) return x;
    return acc[x][0];
}

int main(){
    freopen ("tree.in", "r", stdin);
    freopen ("tree.out", "w", stdout);
    //cout << sizeof(pw) + sizeof(ans) + sizeof(sum) + sizeof(acc) + sizeof(head) + sizeof(siz) + sizeof(dep) + sizeof(fa) << endl;
    pw[0] = 1; for(int i=1; i<=P; i++) pw[i] = pw[i-1] << 1;
    scanf("%d%d", &n, &q);
    for(register int i=1; i<n; i++){
        int u = read(), v = read();
        adde(u, v); adde(v, u);
    }
    dfs( 1, 1 ); dfs2( 1, 1 );
    /*for(register int i=1; i<=n; i++)  printf("%d\n", ans[i]);
    for(register int i=1; i<=n; i++)    printf("%d\n", sum[i]);*/
    while ( q-- ){
        int u = read(), v = read();
        int LCA = lca(u, v);
        printf("%lld\n", up(ans[v] - ans[LCA] + sum[u] - sum[LCA]));
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章