題目鏈接:https://atcoder.jp/contests/abc149/tasks/abc149_f
題目大意:
給定一個N個節點的樹,每個節點有1/2的概率染黑,有1/2的概率染白,隨機染色之後,在樹上找到一個最小的連通塊S,使得S包圍所有的黑色節點,求S包圍的白色節點的期望個數。對1e9+7取模。
比如,如下樣例:
3 1 2 2 3
只有一種情況S中包含白色節點,即1染黑色,2染白色,3染黑色,這樣的概率是1/8,所以答案爲
125000001
對(1e9+7)取模。
思路:
既然是每個點有一個概率,而期望是線性的,所以我們考慮每個點被包圍的概率p,這樣每個點對答案的期望就是1/2*p(因爲這個點必須染白色)。
一個點如果必須被S覆蓋,說明與他相連的幾個連通塊中,至少有兩個連通塊包含黑色節點,那麼兩個連通塊如果需要合併的話,必須繞過這個點,此時這個點對答案有貢獻。
所以我們的任務就是對於每個點,求出與他相連的幾個連通塊,”至少有兩個連通塊包含黑點“的概率。
怎麼算呢?
如下圖:
三個連通塊大小如圖。我們先求出每個連通塊包含黑點的概率pi(這個就很好求了,1減去連通塊都是白點的情況),然後反過來考慮,至多隻有一個連通塊包含黑點的概率=全是百點的概率+有一個連通塊包含黑點的概率。
比如圖中至多隻有一個連通塊包含黑點的概率=
然後用1減去這個概率就是我們要求的了。
這樣基本上的思路就完成了。細節看代碼:
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+10;
const int mod=1e9+7;
int pow_mod(int a,int n,int mod){
if(n==0)return 1;
int ans=pow_mod(a,n/2,mod);
ans=ans*ans%mod;
if(n%2==1)ans=ans*a%mod;
return ans;
}
//擴展GCD
int ex_gcd(int a,int b,int &x,int &y){
if(b==0){x = 1;y = 0;return a;}
int g = ex_gcd(b,a%b,x,y);
int temp = x;
x = y;
y = temp - a/b*y;
return g;
}
//擴展歐幾里得的另一種寫法
void ex_gcd(int a,int b,int &d,int &x,int &y){
//求解ax+by=gcd(a,b)的一組解
if(!b){
d=a,x=1,y=0;
}
else{
ex_gcd(b,a%b,d,y,x);
y-=x*(a/b);
}
}
//逆元
int inv(int a,int mod){
int X,Y;
int g = ex_gcd(a,mod,X,Y);
if(g!=1)return -1;
return (X%mod + mod)%mod;
}
vector<int>G[maxn];
int sz[maxn];
int fa[maxn];
void dfs(int u,int f){
sz[u]=1;
fa[u]=f;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v==f)continue;
dfs(v,u);
sz[u]+=sz[v];
}
}
int p[maxn];
signed main(){
int n;
scanf("%lld",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%lld%lld",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,-1);
int ans=0;
int inv2=inv(2,mod);
for(int i=1;i<=n;i++){
int tot=G[i].size();
if(G[i].size()==1)continue;
int now=1;
for(int j=0;j<tot;j++){
int v=G[i][j];
int siz=0;
if(v==fa[i])siz=n-sz[i];
else siz=sz[v];
p[j]=(1-inv(pow_mod(2,siz,mod),mod)+mod)%mod;//點黑的概率
now*=((1-p[j]+mod)%mod);
now%=mod;
}
int cur=now;
for(int j=0;j<tot;j++){
cur+=now*inv((1-p[j]+mod)%mod,mod)%mod*p[j]%mod;
cur%=mod;
}
ans+=inv2*(1-cur+mod)%mod;
ans%=mod;
}
cout<<ans<<endl;
return 0;
}