codeforces 452E. Three strings 後綴數組

題意:給三個串s1,s2,s3,對於每個長度L(1<=L<=min(length(s1,s2,s3)))求有多少個三元組<p1,p2,p3>使得s1[p1...p1+L-1]==s2[p2..p2+L-1]==s3[p3..p3+L-1],求出所有L對應的答案對1e9+7取模。

題解:

先把串連在一起,記錄每個位置所屬的串,跑一下後綴數組求出height數組。

如果從位置pos開始的最大公共子串長L,那麼所有從pos開始的長度小於L的子串都是符合條件的。

這裏可以按height從大到小的順序,一邊計算一邊用並查集合並區間,對於要合併的兩個位置u1,u2,首先他們一定是相鄰。用sum[type][pos]來維護每個位置pos所包含的s1,s2,s3的數量,那麼u1與u2合併以後(以u1爲fa),對答案的貢獻就是

sum[0][u1]*sum[1][u1]*sum[2][u1],因爲涉及重複統計,那麼在相加之前,要先減掉u1和u2中的貢獻。

#include<bits/stdc++.h>
#define rint register int
#define inv inline void
#define ini inline int
#define maxn 3000050
using namespace std;
typedef long long ll;
const ll mod=1000000007;
char s[maxn],t[maxn];
int y[maxn],x[maxn],c[maxn],sa[maxn],rk[maxn],height[maxn],wt[30];
int n,m;
inv putout(int x) {
    if(!x) {
        putchar(48);
        return;
    }
    rint l=0;
    while(x) wt[++l]=x%10,x/=10;
    while(l) putchar(wt[l--]+48);
}
inv get_SA() {
    for (rint i=1; i<=n; ++i) ++c[x[i]=s[i]];
    for (rint i=2; i<=m; ++i) c[i]+=c[i-1];
    for (rint i=n; i>=1; --i) sa[c[x[i]]--]=i;
    for (rint k=1; k<=n; k<<=1) {
        rint num=0;
        for (rint i=n-k+1; i<=n; ++i) y[++num]=i;
        for (rint i=1; i<=n; ++i) if (sa[i]>k) y[++num]=sa[i]-k;
        for (rint i=1; i<=m; ++i) c[i]=0;
        for (rint i=1; i<=n; ++i) ++c[x[i]];
        for (rint i=2; i<=m; ++i) c[i]+=c[i-1];
        for (rint i=n; i>=1; --i) sa[c[x[y[i]]]--]=y[i],y[i]=0;
        swap(x,y);
        x[sa[1]]=1;
        num=1;
        for (rint i=2; i<=n; ++i)
            x[sa[i]]=(y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k]) ? num : ++num;
        if (num==n) break;
        m=num;
    }
}
inv get_height() {
    rint k=0;
    for (rint i=1; i<=n; ++i) rk[sa[i]]=i;
    for (rint i=1; i<=n; ++i) {
        if (rk[i]==1) continue;//第一名height爲0
        if (k) --k;//h[i]>=h[i-1]-1;
        rint j=sa[rk[i]-1];
        while (j+k<=n && i+k<=n && s[i+k]==s[j+k]) ++k;
        height[rk[i]]=k;//h[i]=height[rk[i]];
    }
}
int p[maxn],belong[maxn],fa[maxn];
ll sum[4][maxn],ans[maxn];
char ss[maxn];
bool cmp(int x,int y){
    return height[x]>height[y];
}
inline int findfa(int u){
    if(fa[u]==u) return u;
    else return fa[u]=findfa(fa[u]);
}
int main() {
    scanf("%s",ss+1);
    int len=strlen(ss+1);
    int aa=len;
    n=0;
    for(int i=1;i<=len;i++) belong[++n]=0,s[n]=ss[i];
    s[++n]=1;
    belong[n]=-1;
    scanf("%s",ss+1);
    len=strlen(ss+1);aa=min(aa,len);
    for(int i=1;i<=len;i++) belong[++n]=1,s[n]=ss[i];
    s[++n]=2;
    belong[n]=-1;
    scanf("%s",ss+1);
    len=strlen(ss+1);aa=min(aa,len);
    for(int i=1;i<=len;i++) belong[++n]=2,s[n]=ss[i];
    s[++n]=3;
    belong[n]=-1;
    m=200;
    get_SA();
    get_height();
    for(int i=1;i<=n;i++){
        //printf("@%d\n",height[i]);
        p[i]=i;
        fa[i]=i;
        sum[0][i]=sum[1][i]=sum[2][i]=0;
        if(belong[i]==0) sum[0][i]=1;
        if(belong[i]==1) sum[1][i]=1;
        if(belong[i]==2) sum[2][i]=1;
    }
    sort(p+1,p+1+n,cmp);

    int j=1;
    ll tmp=0;
    for(int i=aa;i>=1;i--){
        while(j<=n&&height[p[j]]>=i){
            //printf("!%d %d\n",sa[p[j]],p[j]);
            int l=findfa(sa[p[j]-1]);
            int r=findfa(sa[p[j]]);
            tmp=(tmp+mod-(sum[0][l]*sum[1][l]%mod*sum[2][l]%mod))%mod;
            tmp=(tmp+mod-(sum[0][r]*sum[1][r]%mod*sum[2][r]%mod))%mod;
            sum[0][l]+=sum[0][r];
            sum[1][l]+=sum[1][r];
            sum[2][l]+=sum[2][r];
            tmp=(tmp+(sum[0][l]*sum[1][l]%mod*sum[2][l]%mod))%mod;
            fa[r]=l;
            j++;
        }
        tmp%=mod;
        ans[i]=tmp;
    }
    for(int i=1;i<=aa;i++) printf("%lld ",ans[i]);
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章