题目链接:https://atcoder.jp/contests/abc149/tasks/abc149_f
题目大意:
给定一个N个节点的树,每个节点有1/2的概率染黑,有1/2的概率染白,随机染色之后,在树上找到一个最小的连通块S,使得S包围所有的黑色节点,求S包围的白色节点的期望个数。对1e9+7取模。
比如,如下样例:
3 1 2 2 3
只有一种情况S中包含白色节点,即1染黑色,2染白色,3染黑色,这样的概率是1/8,所以答案为
125000001
对(1e9+7)取模。
思路:
既然是每个点有一个概率,而期望是线性的,所以我们考虑每个点被包围的概率p,这样每个点对答案的期望就是1/2*p(因为这个点必须染白色)。
一个点如果必须被S覆盖,说明与他相连的几个连通块中,至少有两个连通块包含黑色节点,那么两个连通块如果需要合并的话,必须绕过这个点,此时这个点对答案有贡献。
所以我们的任务就是对于每个点,求出与他相连的几个连通块,”至少有两个连通块包含黑点“的概率。
怎么算呢?
如下图:
三个连通块大小如图。我们先求出每个连通块包含黑点的概率pi(这个就很好求了,1减去连通块都是白点的情况),然后反过来考虑,至多只有一个连通块包含黑点的概率=全是百点的概率+有一个连通块包含黑点的概率。
比如图中至多只有一个连通块包含黑点的概率=
然后用1减去这个概率就是我们要求的了。
这样基本上的思路就完成了。细节看代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+10;
const int mod=1e9+7;
int pow_mod(int a,int n,int mod){
if(n==0)return 1;
int ans=pow_mod(a,n/2,mod);
ans=ans*ans%mod;
if(n%2==1)ans=ans*a%mod;
return ans;
}
//扩展GCD
int ex_gcd(int a,int b,int &x,int &y){
if(b==0){x = 1;y = 0;return a;}
int g = ex_gcd(b,a%b,x,y);
int temp = x;
x = y;
y = temp - a/b*y;
return g;
}
//扩展欧几里得的另一种写法
void ex_gcd(int a,int b,int &d,int &x,int &y){
//求解ax+by=gcd(a,b)的一组解
if(!b){
d=a,x=1,y=0;
}
else{
ex_gcd(b,a%b,d,y,x);
y-=x*(a/b);
}
}
//逆元
int inv(int a,int mod){
int X,Y;
int g = ex_gcd(a,mod,X,Y);
if(g!=1)return -1;
return (X%mod + mod)%mod;
}
vector<int>G[maxn];
int sz[maxn];
int fa[maxn];
void dfs(int u,int f){
sz[u]=1;
fa[u]=f;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v==f)continue;
dfs(v,u);
sz[u]+=sz[v];
}
}
int p[maxn];
signed main(){
int n;
scanf("%lld",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%lld%lld",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,-1);
int ans=0;
int inv2=inv(2,mod);
for(int i=1;i<=n;i++){
int tot=G[i].size();
if(G[i].size()==1)continue;
int now=1;
for(int j=0;j<tot;j++){
int v=G[i][j];
int siz=0;
if(v==fa[i])siz=n-sz[i];
else siz=sz[v];
p[j]=(1-inv(pow_mod(2,siz,mod),mod)+mod)%mod;//点黑的概率
now*=((1-p[j]+mod)%mod);
now%=mod;
}
int cur=now;
for(int j=0;j<tot;j++){
cur+=now*inv((1-p[j]+mod)%mod,mod)%mod*p[j]%mod;
cur%=mod;
}
ans+=inv2*(1-cur+mod)%mod;
ans%=mod;
}
cout<<ans<<endl;
return 0;
}