Tree(樹形DP)

Tree(樹形DP)

不太熟,記錄一下思路

傳送門

思路:對於樹上某一個點的連通點集的數量,我們可以以它爲根進行樹上dpdp

根據乘法原理:dp[u]=dp[u]×vson(dp[v]+1)dp[u]=dp[u]\times\prod\limits_{v\in son}(dp[v]+1)

舉個例子:

aa爲根,顯然dp[b]=dp[c]=1,dp[b]=dp[c]=1,通過邊(a,b)(a,b)產生的貢獻是dp[b]+1dp[b]+1。這個加1表示不選bb的子集,然後每個dp[b]+1dp[b]+1又會對應dp[c]+1dp[c]+1種情況。

因此可以用乘法原理求得。

但是此題要求我們求出所有結點的連通點集數量,顯然一個個dfsdfs會超時。

所以我們考慮由已知的dp[fa]dp[fa],進行狀態轉移。

顯然若uufafadp[fa]dp[fa]已知,當我們以uu爲根結點求答案時,我們只需用dp[u]×(udp[u]\times (u的兒子fafa的貢獻),因爲dp[fa]dp[fa]存在重複計算dp[u]dp[u]的情況,所以我們應該先把dp[fa]dp[fa]除以dp[u]+1dp[u]+1還原,然後dp[fa]dp[u]+1+1\dfrac{dp[fa]}{dp[u]+1}+1即是fafa的貢獻。

綜上:狀態轉移方程爲:dp[u]=dp[u]×(dp[fa]dp[u]+1+1)dp[u]=dp[u]\times(\dfrac{dp[fa]}{dp[u]+1}+1)

另外由於此題數據很大,需要取模,而且我們進行狀態轉移的時候存在除法的情況。

顯然我們需要對dp[u]+1=0dp[u]+1=0的情況進行特判。

我們需要再開一個數組來維護特判00之後的dp[fa]dp[u]+1\dfrac{dp[fa]}{dp[u]+1},我們記爲dp1[fa]dp1[fa]

然後我們再以fafa爲之前的根結點進行狀態轉移,注意不能算ufa[fa]u和fa[fa]

因爲fa[fa]fa[fa]已經轉移過了,即(dp1[fa]+1)dp1[fa]+1)包含了dp[fa[fa]]dp[fa[fa]]的情況,就相當於隔斷了,以fafa原來的根轉移。

時間複雜度:O(2nlogn)O(2nlogn)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+5,M=1e6+5,inf=0x3f3f3f3f,mod=1e9+7;
#define mst(a) memset(a,0,sizeof a)
#define lx x<<1
#define rx x<<1|1
#define reg register
#define PII pair<int,int>
#define fi first 
#define se second
ll dp[N],dp1[N],ans[N];
int n,Fa[N];
vector<int>e[N];
inline void read(int &x){ 
	x=0;int w=1;
	char ch=getchar();
	while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
	for(;ch>='0'&&ch<='9';ch=getchar())
		x=(x<<3)+(x<<1)+(ch&15);
	x*=w; 
}
ll ksm(ll a,ll n){
	ll ans=1;
	while(n){
		if(n&1) ans=ans*a%mod;
		a=a*a%mod;
		n>>=1;
	}
	return ans;
}
void dfs(int u,int fa){
	dp[u]=1,Fa[u]=fa;
	for(auto v:e[u]){
		if(v==fa) continue;
		dfs(v,u);
		dp[u]=dp[u]*(dp[v]+1)%mod;
	}
}
void dfs1(int u,int fa){
	if(!fa) ans[u]=dp[u];
	else {
		ll tmp;
		if((dp[u]+1)%mod==0){
			 tmp=dp1[fa]+1; 
			for(auto v:e[fa]){
				if(v==u||v==Fa[fa]) continue; 
				tmp=tmp*(dp[v]+1)%mod;
			}
			//printf("%lld\n",tmp);
			dp1[u]=tmp;
			ans[u]=dp[u]*(tmp+1)%mod; 
		}
		else tmp=ans[fa]*ksm(dp[u]+1,mod-2)%mod,ans[u]=dp[u]*(tmp+1)%mod;
	}
	for(auto v:e[u]){
		if(v==fa) continue;
		dfs1(v,u);
	}
} 
int main(){
	int  n;
	read(n);
	for(reg int i=1;i<n;i++){
		int u,v;
		read(u),read(v);
		e[u].push_back(v),e[v].push_back(u);
	}
	dfs(1,0);
	dfs1(1,0);
	for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章