Tree(树形DP)

Tree(树形DP)

不太熟,记录一下思路

传送门

思路:对于树上某一个点的连通点集的数量,我们可以以它为根进行树上dpdp

根据乘法原理:dp[u]=dp[u]×vson(dp[v]+1)dp[u]=dp[u]\times\prod\limits_{v\in son}(dp[v]+1)

举个例子:

aa为根,显然dp[b]=dp[c]=1,dp[b]=dp[c]=1,通过边(a,b)(a,b)产生的贡献是dp[b]+1dp[b]+1。这个加1表示不选bb的子集,然后每个dp[b]+1dp[b]+1又会对应dp[c]+1dp[c]+1种情况。

因此可以用乘法原理求得。

但是此题要求我们求出所有结点的连通点集数量,显然一个个dfsdfs会超时。

所以我们考虑由已知的dp[fa]dp[fa],进行状态转移。

显然若uufafadp[fa]dp[fa]已知,当我们以uu为根结点求答案时,我们只需用dp[u]×(udp[u]\times (u的儿子fafa的贡献),因为dp[fa]dp[fa]存在重复计算dp[u]dp[u]的情况,所以我们应该先把dp[fa]dp[fa]除以dp[u]+1dp[u]+1还原,然后dp[fa]dp[u]+1+1\dfrac{dp[fa]}{dp[u]+1}+1即是fafa的贡献。

综上:状态转移方程为:dp[u]=dp[u]×(dp[fa]dp[u]+1+1)dp[u]=dp[u]\times(\dfrac{dp[fa]}{dp[u]+1}+1)

另外由于此题数据很大,需要取模,而且我们进行状态转移的时候存在除法的情况。

显然我们需要对dp[u]+1=0dp[u]+1=0的情况进行特判。

我们需要再开一个数组来维护特判00之后的dp[fa]dp[u]+1\dfrac{dp[fa]}{dp[u]+1},我们记为dp1[fa]dp1[fa]

然后我们再以fafa为之前的根结点进行状态转移,注意不能算ufa[fa]u和fa[fa]

因为fa[fa]fa[fa]已经转移过了,即(dp1[fa]+1)dp1[fa]+1)包含了dp[fa[fa]]dp[fa[fa]]的情况,就相当于隔断了,以fafa原来的根转移。

时间复杂度:O(2nlogn)O(2nlogn)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+5,M=1e6+5,inf=0x3f3f3f3f,mod=1e9+7;
#define mst(a) memset(a,0,sizeof a)
#define lx x<<1
#define rx x<<1|1
#define reg register
#define PII pair<int,int>
#define fi first 
#define se second
ll dp[N],dp1[N],ans[N];
int n,Fa[N];
vector<int>e[N];
inline void read(int &x){ 
	x=0;int w=1;
	char ch=getchar();
	while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
	for(;ch>='0'&&ch<='9';ch=getchar())
		x=(x<<3)+(x<<1)+(ch&15);
	x*=w; 
}
ll ksm(ll a,ll n){
	ll ans=1;
	while(n){
		if(n&1) ans=ans*a%mod;
		a=a*a%mod;
		n>>=1;
	}
	return ans;
}
void dfs(int u,int fa){
	dp[u]=1,Fa[u]=fa;
	for(auto v:e[u]){
		if(v==fa) continue;
		dfs(v,u);
		dp[u]=dp[u]*(dp[v]+1)%mod;
	}
}
void dfs1(int u,int fa){
	if(!fa) ans[u]=dp[u];
	else {
		ll tmp;
		if((dp[u]+1)%mod==0){
			 tmp=dp1[fa]+1; 
			for(auto v:e[fa]){
				if(v==u||v==Fa[fa]) continue; 
				tmp=tmp*(dp[v]+1)%mod;
			}
			//printf("%lld\n",tmp);
			dp1[u]=tmp;
			ans[u]=dp[u]*(tmp+1)%mod; 
		}
		else tmp=ans[fa]*ksm(dp[u]+1,mod-2)%mod,ans[u]=dp[u]*(tmp+1)%mod;
	}
	for(auto v:e[u]){
		if(v==fa) continue;
		dfs1(v,u);
	}
} 
int main(){
	int  n;
	read(n);
	for(reg int i=1;i<n;i++){
		int u,v;
		read(u),read(v);
		e[u].push_back(v),e[v].push_back(u);
	}
	dfs(1,0);
	dfs1(1,0);
	for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章