Tree (树上期望dp)

Tree

11.2 和 11.3

11.2

考虑1个随机过程,第1次走到u号点的时间可以分成两部分,第1部分是从1号点随机游走第1次走到u的父亲p的时间,第2部分是从p开始走,第1次走到u的时间,由期望的线性性,第1次走到u的时间期望等于这两部分期望的和。第1部分是一个子问题,我们考虑怎么解决第2部分,我们把这个问题变成1棵树(并且根节点脑袋上也有1条边),从根节点开始随机游走,走出这棵树期望的时间,我们x[u]表示这个期望,我们对u的子树中的点也类似地定义x[v],这样我们可以列出关系式:

这里写图片描述

其中d是的u度数(包括那根天线),这个关系是中的第一个1表示直接向上.,后面那个扩号中的三部分,那个1表示从u走向v, x[v]表示从v回来期望时间, x[u]表示这个时候继续走,走出去还需要花的时间。因为是等概率,所以直接乘以1/d这个概率即可。化简后是:

这里写图片描述

即x[u]等于这棵子树的所有节点度的和,考虑到除了那根天线之外,所有的边对度的贡献为2,所以:

这里写图片描述

这样,子问题就有了一个简单的答案了。我们回到原问题,dp[u]表示第一次走到u的期望时间,p表示u的父亲,有:

这里写图片描述

完美解决了这个问题,复杂度O(n),其实答案都是整数,那三位小数也是来骗你的^_^。

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#define LL long long
#define N 1000010
using namespace std;

int n, idc=0, idx=0;
int head[N], siz[N], fa[N], seq[N];
LL ans[N];

struct Edge{
    int to, nxt;
}ed[N];

void adde(int u, int v){
    ed[++idc].to = v;
    ed[idc].nxt = head[u];
    head[u] = idc;
}

void dfs(int u){
    seq[++idx] = u;
    siz[u] = 1;
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == fa[u]) continue;
        fa[v] = u;
        dfs( v ); 
        siz[u] += siz[v];
    }
}

void dfs2(int u, int fa){
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == fa) continue;
        ans[v] = ans[u] + 2 * (n - siz[v]) - 1;
        dfs2(v, u);
    }
}

int main(){
    freopen ("tree.in", "r", stdin);
    freopen ("tree.out", "w", stdout);
    scanf("%d", &n); ans[1] = 1;
    for(int i=1; i<n; i++){
        int u, v; scanf("%d%d", &u, &v); 
        adde(u, v); adde(v, u);
    }
    dfs( 1 );
    ans[1]=1;
    dfs2(1,1);
    for(int i=1; i<=n; i++)
        printf("%lld.000\n", ans[i]);
    return 0;
}

11.3

从u走到v一定是从u到fa[u], 在到fa[fa[u]], …, 再从u,v的lca一步一步走到v. 那么如果能够算出从u到fa[u]的期望步数(设为f[u])和从fa[u]到u的期望步数(设为g[u])就能够做了.

这里写图片描述

式子就是在枚举第一步怎么走来列方程. 化简发现f[u]和g[u]其实是整数. 先求f,再求g,
然后求lca就可以了.

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#define LL long long
#define N 200010
#define mod 1000000007
#define P 17
using namespace std;

int n, q, idc=0, idx=0;
int head[N], siz[N], fa[N], dep[N];
int pw[P+1], acc[N][P+1];
LL ans[N], sum[N];

inline int read(){
    int x = 0, f = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){ if(ch == '-') f = -1; ch = getchar(); }
    while(ch >= '0' && ch <= '9'){ x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}

LL up(LL x){
    while (x >= mod) x -= mod;
    while (x < 0) x += mod;
    return x;
}

struct Edge{
    int to, nxt;
}ed[N<<1];

inline void adde(int u, int v){
    ed[++idc].to = v;
    ed[idc].nxt = head[u];
    head[u] = idc;
}

inline void dfs(int u, int f){
    siz[u] = 1;
    acc[u][0] = f; 
    for(int i=1; i<=P; i++) 
        acc[u][i] = acc[acc[u][i-1]][i-1];
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs( v, u ); 
        siz[u] += siz[v];
    }
}

inline void dfs2(int u, int f){
    for(int i=head[u]; i; i=ed[i].nxt){
        int v = ed[i].to;
        if(v == f) continue;
        ans[v] = up(ans[u] + 2LL * (n - siz[v]) - 1LL);
        sum[v] = up(sum[u] + 2LL * siz[v] - 1LL);
        dfs2(v, u);
    }
}

inline int lca(int x, int y){
    if(dep[x] < dep[y]) swap(x, y);
    int t = dep[x] - dep[y];
    for(int i=0; pw[i]<=t; i++)
        if(t & pw[i]) x = acc[x][i];
    for(int i=P; i>=0; i--)
        if(acc[x][i] != acc[y][i])
            x = acc[x][i], y = acc[y][i];
    if(x == y) return x;
    return acc[x][0];
}

int main(){
    freopen ("tree.in", "r", stdin);
    freopen ("tree.out", "w", stdout);
    //cout << sizeof(pw) + sizeof(ans) + sizeof(sum) + sizeof(acc) + sizeof(head) + sizeof(siz) + sizeof(dep) + sizeof(fa) << endl;
    pw[0] = 1; for(int i=1; i<=P; i++) pw[i] = pw[i-1] << 1;
    scanf("%d%d", &n, &q);
    for(register int i=1; i<n; i++){
        int u = read(), v = read();
        adde(u, v); adde(v, u);
    }
    dfs( 1, 1 ); dfs2( 1, 1 );
    /*for(register int i=1; i<=n; i++)  printf("%d\n", ans[i]);
    for(register int i=1; i<=n; i++)    printf("%d\n", sum[i]);*/
    while ( q-- ){
        int u = read(), v = read();
        int LCA = lca(u, v);
        printf("%lld\n", up(ans[v] - ans[LCA] + sum[u] - sum[LCA]));
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章