參考博客
http://blog.csdn.net/firenet1/article/details/47445921
http://blog.csdn.net/pibaixinghei/article/details/52783432
有兩種方法,一種是計數DP,另一種是概率DP。
計數DP:
應該都能想到dp[i][j]表示以i爲根的子樹,有j個領導。接下來考慮狀態轉移。
自己一開始考慮枚舉分配方案,就是對dp[i][j],枚舉j個領導如何分配給兒子節點,但是這樣時間複雜度肯定是不能接受的。
事實上這樣的枚舉浪費了組合數公式,我們可以考慮將其中組合方面的枚舉提取成公式計算出來。
降低時間複雜度的方法有很多,可以考慮交換枚舉順序,改變枚舉量,動態改變枚舉範圍,維護一些值等技巧來降低時間複雜度。
由於各個子樹的方案相互獨立,因此我們可以用乘法原理,逐個考慮每個兒子,用組合數公式對每一個兒子的所有可能組合計算完,然後拋棄之。
時間複雜度O(n^2),稍加改造就可以O(nk)
概率DP:
我們也可以計算出概率,然後乘以總次數。
概率一般都是小數或者分數,如果用double肯定會有精度問題,如果用分數一定會爆long long。
事實上我們可以用逆元來解決這個問題。
逆元可以解決過程是小數,但是初始和結果都是整數的問題。
時間複雜度O(nk)
計數DP代碼
#include<stdio.h>
#include<vector>
using namespace std;
const int maxn = 1010;
const int mod = 1000000007;
int n,k;
vector<int>G[maxn];
int dp[maxn][maxn];
int siz[maxn];
int C[maxn][maxn];
int tp[maxn];
void read()
{
scanf("%d %d",&n,&k);
for(int i=1;i<=n;i++) G[i].clear();
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d %d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
}
void dfs1(int u,int f)
{
siz[u]=1;
for(int i=0;i<(int)G[u].size();i++)
{
int v = G[u][i];
if(v==f) continue;
dfs1(v,u);
siz[u]+=siz[v];
}
}
void dfs2(int u,int f)
{
int e = 1;
dp[u][1]=1;
dp[u][0]=siz[u]-1;
for(int ii=0;ii<(int)G[u].size();ii++)
{
int v = G[u][ii];
if(v==f) continue;
dfs2(v,u);
for(int i=0;i<=e+siz[v];i++)
tp[i]=0;
for(int i=0;i<=siz[v];i++)
dp[v][i]=1ll*dp[v][i]*C[siz[u]-e][siz[v]]%mod;
for(int i=0;i<=siz[v];i++)
for(int j=0;j<=e;j++)
tp[i+j]=(tp[i+j]+1ll*dp[u][j]*dp[v][i])%mod;
e+=siz[v];
for(int i=0;i<=e;i++)
dp[u][i]=tp[i];
}
}
void solve()
{
read();
dfs1(1,0);
dfs2(1,0);
printf("%d\n",dp[1][k]);
}
void init()
{
C[0][0]=1;
for(int i=1;i<maxn;i++)
{
C[i][0]=1;
for(int j=1;j<=i;j++)
C[i][j]=(1ll*C[i-1][j]+C[i-1][j-1])%mod;
}
}
int main()
{
init();
int T;
scanf("%d",&T);
for(int t=1;t<=T;t++)
{
printf("Case #%d: ",t);
solve();
}
return 0;
}
概率DP代碼
#include<stdio.h>
#include<vector>
using namespace std;
const int mod = 1000000007;
const int maxn = 1010;
vector<int>G[maxn];
int dp[maxn][maxn];
int mp(int x,int n)
{
int ret=1;
while(n)
{
if(n&1) ret=1ll*ret*x%mod;
x=1ll*x*x%mod;
n>>=1;
}
return ret;
}
int inv[maxn];
int siz[maxn];
int fac[maxn];
int n,k;
void read()
{
scanf("%d %d",&n,&k);
for(int i=1;i<=n;i++) G[i].clear();
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d %d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
}
void dfs(int u,int f)
{
siz[u]=1;
for(int i=0;i<(int)G[u].size();i++)
{
int v = G[u][i];
if(v==f) continue;
dfs(v,u);
siz[u]+=siz[v];
}
}
void solve()
{
read();
dfs(1,0);
dp[0][0]=1;
for(int i=1;i<=n;i++)
{
dp[i][0]=1ll*dp[i-1][0]*(siz[i]-1)%mod*inv[siz[i]]%mod;
for(int j=1;j<=k;j++)
dp[i][j]=(1ll*dp[i-1][j]*(siz[i]-1)%mod*inv[siz[i]]%mod+1ll*dp[i-1][j-1]*inv[siz[i]]%mod)%mod;
}
printf("%d\n",int(1ll*dp[n][k]*fac[n]%mod));
}
void init()
{
fac[0]=1;
for(int i=1;i<maxn;i++)
{
inv[i]=mp(i,mod-2);
fac[i]=1ll*fac[i-1]*i%mod;
}
}
int main()
{
init();
int T;
scanf("%d",&T);
for(int t=1;t<=T;t++)
{
printf("Case #%d: ",t);
solve();
}
return 0;
}