洛谷4895 BZOJ3162 獨釣寒江雪 樹形dp 樹哈希

題目鏈接

題意:
給定一棵無根樹,求其中本質不同的獨立集的個數。獨立集就是一個集合中的點之間都沒有邊直接相連。n<=5e5n<=5e5,對1e9+71e9+7取模。

題解:
首先膜拜一下y_immortaly\_immortal神仙,是這個神仙教的我這個題怎麼做QwQ.

首先考慮沒有本質不同應該怎麼算。我們設dp[x][0]dp[x][0]表示考慮xx爲根的子樹內不選xx這個點的方案數,設dp[x][1]dp[x][1]表示考慮xx爲根的子樹內選xx這個點的方案數。我們枚舉xx的每個子樹,我們用子樹的方案數乘起來就是答案。dp[x][0]=yson[x](dp[y][0]+dp[y][1])dp[x][0]=\prod_{y\in son[x]}(dp[y][0]+dp[y][1]),dp[x][1]=yson[x]dp[y][0]dp[x][1]=\prod_{y\in son[x]}dp[y][0]

於是考慮有本質不同怎麼來算。我們設dp[x][0]dp[x][0]表示考慮xx爲根的子樹內不選xx這個點本質不同的獨立集數,設dp[x][1]dp[x][1]表示考慮xx爲根的子樹選xx這個點的本質不同的獨立集數。我們把本質相同的樹放在一起考慮,我們假設現在考慮到的這種本質相同的子樹在xx的子樹中有kk棵,這種子樹的根設爲yy節點。爲了保證算方案的時候不會重複,我們給所有yy中可以的方案編一個號,也就是說yy子樹中有多少種方案,最大的一個編號就是多少。之後我們爲了不重複,所有相同的這些子樹,規定前面的子樹選的編號要小於等於後面的,這樣就可以不重不漏。而這個東西應該是一個可重複的組合數。於是我們把每一類本質相同的yy放在一起算,有dp[x][0]=yson[x](dp[y][1]+dp[y][0])Cdp[y][1]+dp[y][0]+k1kdp[x][0]=\prod_{y\in son[x]}(dp[y][1]+dp[y][0])*C_{dp[y][1]+dp[y][0]+k-1}^{k} , dp[x][1]=yson[x]dp[y][0]Cdp[y][0]+k1kdp[x][1]=\prod_{y\in son[x]}dp[y][0]*C_{dp[y][0]+k-1}^{k}

那麼下面的問題是如何判斷兩棵樹是否本質相同。我們要做的是樹哈希。這裏提供一種樹哈希的方法。我們設葉子節點的哈希值是11,然後對於其他點,我們把他們的子樹按照哈希值排序,然後依次乘進去,乘的時候像字符串哈希那樣,把每一個子樹看作一個字符,乘一個底數的多少次冪再加進來。最後再乘一個這個子樹的size。反正這樣起碼能保證相同的不會判成不同。

另外就是求組合數的時候,沒法求1e9+71e9+7那麼大的,不過我們發現,大部分項可以被約掉,就剩下mm項,我們暴力算這mm項就好,我們每個點最多被在組合數裏算一次,所以最終還是線性的。

複雜度O(n)O(n)

代碼:

#include <bits/stdc++.h>
using namespace std;

int n,hed[500010],cnt,sz[500010],mx[500010],rt,rt1,rt2,fa[500010];
long long ans,jie[500010],ni[500010],mi[500010],ha[500010],dp[500010][2];
const long long mod=1e9+7,base=23333;
struct node
{
	int to,next;
}a[2000010];
inline int read()
{
	int x=0;
	char s=getchar();
	while(s>'9'||s<'0')
	s=getchar();
	while(s>='0'&&s<='9')
	{
		x=x*10+s-'0';
		s=getchar();
	}
	return x;
}
inline long long ksm(long long x,long long y)
{
	long long res=1;
	while(y)
	{
		if(y&1)
		res=res*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return res;
}
inline void add(int from,int to)
{
	a[++cnt].to=to;
	a[cnt].next=hed[from];
	hed[from]=cnt;
}
inline void getrt(int x,int f)
{
	sz[x]=1;
	mx[x]=0;
	for(int i=hed[x];i;i=a[i].next)
	{
		int y=a[i].to;
		if(y==f)
		continue;
		getrt(y,x);
		sz[x]+=sz[y];
		mx[x]=max(sz[y],mx[x]);
	}
	mx[x]=max(mx[x],n-sz[x]);
	if(mx[rt1]>mx[x])
	rt1=x;
	else if(mx[x]==mx[rt1])
	rt2=x;
}
inline void dfs(int x)
{
	sz[x]=1;
	for(int i=hed[x];i;i=a[i].next)
	{
		int y=a[i].to;
		if(y==fa[x])
		continue;
		fa[y]=x;
		dfs(y);
		sz[x]+=sz[y];
	}
}
inline int cmp(int x,int y)
{
	return ha[x]<ha[y];
}
inline void dfs1(int x)
{
	int num=0;
	vector<int> v;
	v.clear();
	ha[x]=1;
	for(int i=hed[x];i;i=a[i].next)
	{
		int y=a[i].to;
		if(y==fa[x])
		continue;
		dfs1(y);
	}
	for(int i=hed[x];i;i=a[i].next)
	{
		int y=a[i].to;
		if(y==fa[x])
		continue;
		v.push_back(y);
		++num;
	}
	sort(v.begin(),v.end(),cmp);
	for(int i=0;i<num;++i)
	ha[x]=(ha[x]+ha[v[i]]*mi[i+1])%mod;
	ha[x]=ha[x]*sz[x]%mod;
}
inline long long C(int n,int m)
{
	n%=mod;
	long long res=1;
	for(int i=n-m+1;i<=n;++i)
	res=res*i%mod;
	res=res*ni[m]%mod;
	return res;
}
inline void dfs2(int x)
{
	dp[x][0]=dp[x][1]=1;
	int num=0;
	vector<int> v;
	v.clear();
	for(int i=hed[x];i;i=a[i].next)
	{
		int y=a[i].to;
		if(y==fa[x])
		continue;
		dfs2(y);
		v.push_back(y);
		++num;
	}
	sort(v.begin(),v.end(),cmp);
	int shu=1;
	v.push_back(n+2);
	for(int i=1;i<=num;++i)
	{
		if(ha[v[i]]!=ha[v[i-1]])
		{
			long long qwq=(dp[v[i-1]][0]+dp[v[i-1]][1])%mod,qwqq;
			qwqq=dp[v[i-1]][0];
			dp[x][0]=dp[x][0]*C(qwq+shu-1,shu)%mod;
			dp[x][1]=dp[x][1]*C(qwqq+shu-1,shu)%mod;
			shu=1;
		}
		else
		++shu;
	}
}
int main()
{
	n=read();
	ha[n+2]=mod+2;
	for(int i=1;i<=n-1;++i)
	{
		int x=read(),y=read();
		add(x,y);
		add(y,x);
	}
	mx[0]=2e9;
	getrt(1,0);
	jie[0]=1;
	for(int i=1;i<=n;++i)
	jie[i]=jie[i-1]*i%mod;
	ni[n]=ksm(jie[n],mod-2);
	for(int i=n-1;i>=0;--i)
	ni[i]=ni[i+1]*(i+1)%mod;
	mi[0]=1;
	for(int i=1;i<=n;++i)
	mi[i]=mi[i-1]*base%mod;
	if(mx[rt2]==mx[rt1])
	{
		rt=n+1;
		for(int i=hed[rt1];i;i=a[i].next)
		{
			int y=a[i].to;
			if(y==rt2)
			{
				a[i].to=rt;
				break;
			}
		}
		for(int i=hed[rt2];i;i=a[i].next)
		{
			int y=a[i].to;
			if(y==rt1)
			{
				a[i].to=rt;
				break;
			}
		}
		add(rt,rt1);
		add(rt,rt2);
	}
	else
	rt=rt1;
	memset(sz,0,sizeof(sz));
	dfs(rt);
	dfs1(rt);
	dfs2(rt);
	if(rt==rt1)
	{
		ans=(dp[rt][0]+dp[rt][1])%mod;
		printf("%lld\n",ans);
		return 0;
	}
	if(ha[rt1]==ha[rt2])
	{
		ans=(dp[rt1][0]*dp[rt2][1]%mod+C(dp[rt1][0]+1,2))%mod;
		printf("%lld\n",ans);
	}
	else
	{
		ans=(dp[rt1][0]*dp[rt2][1]%mod+dp[rt1][1]*dp[rt2][0]%mod+dp[rt1][0]*dp[rt2][0]%mod)%mod;
		printf("%lld\n",ans);
	}
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章