codeforces 452E Three strings 後綴數組+並查集

      給三個串s1,s2,s3,對於每個長度L(1<=L<=min(length(s1,s2,s3)))求有多少個三元組<p1,p2,p3>使得s1[p1...p1+L-1]==s2[p2..p2+L-1]==s3[p3..p3+L-1],求出所有L對應的答案對1e9+7取模。

       首先還是將三個串拼起來並標記出每個位置屬於哪個串,求出後綴數組(sa)和相鄰的最長公共前綴數組(height),容易想到遍歷height數組來求出答案,一種比較暴力的做法是,枚舉長度L,對於每個L掃描一遍height數組,按照本次的長度L分組,若該組中包含num1個1串的位置,num2個2串的位置,num3個3串的位置,那麼由乘法原理就對L的答案加上num1*num2*num3,直到掃面完一遍height數組,便可得到L的答案。

但是這麼做複雜度太高了...並且這樣統計會有很多區間被重複統計從而浪費時間,假設對於一個height={10,10,10,10,10,10,10,10,10,10},我們只需計算一次就可以得出它隊L=1..10所有情況答案的貢獻,而枚舉長度的話就要枚舉十次.....這裏可以按height從大到小的順序,一遍計算一邊合併區間,對於要合併的兩個區間u1,u2,首先他們一定是相鄰的,我們可以用三個數組來維護每個區間所包含的s1,s2,s3的位置的數量,那麼u1與u2合併以後,對答案的貢獻就是兩個區間s1位置的數量之和*s2的*s3的,因爲涉及重複統計,那麼在相加之前,要先減掉u1和u2中的答案。具體的細節見下邊代碼了。

/*=============================================================================
#  Author:Erich
#  FileName:
=============================================================================*/
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#include <queue>
#include <stack>
#define lson id<<1,l,m
#define rson id<<1|1,m+1,r

using namespace std;
typedef long long ll;
const int inf=0x3f3f3f3f;
const ll INF=1ll<<60;
const double PI=acos(-1.0);
const int maxn=505000;
const int mod=1e9+7;
char ss[maxn];
int belong[maxn];
int s[maxn],rs[maxn];  
int sa[maxn],t[maxn],t2[maxn],c[maxn];  
int n,m,k,tt;  
int rank[maxn],height[maxn];  
int len,l;
inline int idx(char c)
{
	return c-'a'+3;
}
inline char fdx(int x)
{
	return char(x-3+'a');
}
void getheight(int n)  
{  
    int i,j,k=0;  
    for (i=0; i<=n; i++) rank[sa[i]]=i;  
  
    for (i=0; i<n; i++)  
    {  
        if (k) k--;  
        int j=sa[rank[i]-1];  
        while(s[i+k]==s[j+k]) k++;  
        height[rank[i]]=k;  
    }  
}  
void build_ss(int m,int n)  
{  
    n++;  
    int i,*x=t,*y=t2;  
    for (int i=0; i<m; i++) c[i]=0;  
    for (int i=0; i<n; i++) c[x[i]=s[i]]++;  
    for (int i=1; i<m; i++) c[i]+=c[i-1];  
    for (int i=n-1; i>=0; i--)  
      sa[--c[x[i]]]=i;  
    for (int k=1; k<=n; k<<=1)  
    {  
        int p=0;  
        for (i=n-k; i<n; i++) y[p++]=i;  
        for (i=0; i<n; i++) if (sa[i]>=k) y[p++]=sa[i]-k;  
  
        for (i=0; i<m; i++) c[i]=0;  
        for (i=0; i<n; i++) c[x[y[i]]]++;  
        for (i=1; i<m; i++) c[i]+=c[i-1];  
        for (i=n-1; i>=0; i--) sa[--c[x[y[i]]]] = y[i];  
        swap(x,y);  
        p=1;  
        x[sa[0]]=0;  
        for (i=1; i<n; i++)  
        x[sa[i]]=(y[sa[i-1]]==y[sa[i]] && y[sa[i-1]+k]==y[sa[i]+k])? p-1 : p++;  
        if (p>=n) break;  
        m=p;  
    }  
}
void print(int x)
{
	for (int i=x; i<=n; i++)
	if (s[i]>2)cout<<fdx(s[i]);
	else cout<<s[i];
	cout<<endl;
}
ll ans[maxn];
ll sum[4][maxn];
int fa[maxn];
struct node
{
	int pos,h;
	node(){
	}
	node(int x,int y)
	{
		pos=x; h=y;
	}
	bool operator<(const node&p) const
	{
		return h<p.h;
	}
};
int find(int x)
{
	if (fa[x]==x) return x;
	return fa[x]=find(fa[x]);
}
bool cmp(int x,int y)
{
	return height[x]>height[y];
}
int p[maxn];

int main()
{
	//freopen("in.txt","r",stdin);
	while(~scanf("%s",&ss))
	{
		l=strlen(ss);
		len=l;
		n=0;
		for (int i=0; i<l;i++)	belong[n]=0,s[n++]=idx(ss[i]);
		belong[n]=-1;s[n++]=1;
		scanf("%s",&ss); l=strlen(ss);
		len=min(len,l);
		for (int i=0; i<l; i++) belong[n]=1,s[n++]=idx(ss[i]);
		belong[n]=-1;s[n++]=2;
		scanf("%s",&ss); l=strlen(ss);
		len=min(len,l);
		for (int i=0; i<l; i++) belong[n]=2,s[n++]=idx(ss[i]);
		belong[n]=-1;s[n]=0;
		build_ss(33,n);
		getheight(n);

//	for (int i=0; i<=n; i++)
//		print(sa[i]);
	//	for (int i=1; i<=n; i++) printf("%d%s",height[i],i==n?"\n":" ");	
		for (int i=0; i<=n; i++) p[i]=i,fa[i]=i;
		memset(sum,0,sizeof sum);
		sort(p+1,p+1+n,cmp);
		for (int i=0; i<=n; i++)
		{
			sum[0][i]=sum[1][i]=sum[2][i]=0;
			if (belong[i]==0) sum[0][i]=1;
			if (belong[i]==1) sum[1][i]=1;
			if (belong[i]==2) sum[2][i]=1;
		}
		int j=1;
		ll tmp=0;
		memset(ans,0,sizeof ans);
		for (int i=len; i>=1; i--)
		{
			
			while(j<=n && height[p[j]]>=i)
			{
				int le=find(sa[p[j]-1]),ri=find(sa[p[j]]);
				tmp=tmp-(sum[0][le]%mod*sum[1][le]%mod*sum[2][le]%mod);
				tmp=((tmp%mod)+mod)%mod;
			       	tmp=tmp-(sum[0][ri]%mod*sum[1][ri]%mod*sum[2][ri]%mod);
				tmp=((tmp%mod)+mod)%mod;
				sum[0][le]+=sum[0][ri];
				sum[1][le]+=sum[1][ri];
				sum[2][le]+=sum[2][ri];
				tmp=tmp+(sum[0][le]%mod*sum[1][le]%mod*sum[2][le]%mod);
				tmp%=mod;
				fa[ri]=le;
				j++;
			}
			ans[i]=tmp;
		}
		for (int i=1; i<=len; i++)
			cout<<ans[i]<<" ";
		cout<<endl;
	}
	return 0;
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章