題目大意:
給你一個S串和m個T串,
求S串中有多少子串不是另外m個T串的子串。
解題思路:
對於所有的串,首先將他們連接。接下來我們就要統計對於0到 strlen(s) 這些位置,每個位置有多少與它重複的子串,統計出來之後拿串的總個數減去重複的即可。
統計重複的就從前往後掃一遍 對於在S串和在T串的情況考慮。同時也要從後往前掃一遍。
最後在把S串中自己重複的子串去掉即可。
但是這題讓我明白了好多小細節,之前我對於小寫字母都是直接-'a',但是這個題目就會產生問題,其次每個串之間的分隔符,之前用的一直都是一個沒有出現過的符號插中間,這個題目也會出現問題。可能是我對後綴數組理解的還不夠深刻。。。
Ac代碼:
#include<bits/stdc++.h>
#define rank ra
using namespace std;
const int maxn=3e5+10;
const int INF=1e9+7;
typedef long long ll;
char s[maxn];
int n,sa[maxn],rank[maxn],height[maxn],pos[maxn];
int t1[maxn],t2[maxn],r[maxn],c[maxn];
bool cmp(int *r,int a,int b,int l)
{
return r[a]==r[b] && r[a+l]==r[b+l];
}
void da(int str[],int sa[],int rank[],int height[],int n,int m)
{
n++;
int i,j,p,*x=t1,*y=t2;
for(int i=0;i<m;i++) c[i]=0;
for(int i=0;i<n;i++) c[x[i]=str[i]]++;
for(int i=1;i<m;i++) c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
for(int j=1;j<=n;j<<=1)
{
p=0;
for(int i=n-j;i<n;i++) y[p++]=i;
for(int i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(int i=0;i<m;i++) c[i]=0;
for(int i=0;i<n;i++) c[x[y[i]]]++;
for(int i=1;i<m;i++) c[i]+=c[i-1];
for(int i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1,x[sa[0]]=0;
for(int i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
if(p>=n) break;
m=p;
}
int k=0;
n--;
for(int i=0;i<=n;i++) rank[sa[i]]=i;
for(int i=0;i<n;i++)
{
if(k) k--;
j=sa[rank[i]-1];
while(str[i+k]==str[j+k]) k++;
height[rank[i]]=k;
}
}
int main()
{
int QAQ,kase=0;
scanf("%d",&QAQ);
while(QAQ--)
{
int m; scanf("%d",&m);
scanf(" %s",s);
int ls=strlen(s),len=0,num=30;
for(int i=0;i<ls;i++) r[len++]=s[i]-'a'+1; //這裏注意+1
while(m--)
{
r[len++]=num++; //注意分隔符爲num++
scanf(" %s",s);
int lk=strlen(s);
for(int i=0;i<lk;i++) r[len++]=s[i]-'a'+1;
}
r[len]=0;n=len;
da(r,sa,rank,height,n,num+1);
memset(pos,0,sizeof pos);
int tmp=INF;
for(int i=1;i<=n;i++) //從前往後統計重複子串
{
if(sa[i]<ls)
{
tmp=min(tmp,height[i]);
pos[sa[i]]=max(pos[sa[i]],tmp);
}
else tmp=INF;
}
tmp=INF;
for(int i=n;i>=1;i--) //從後往前統計重複子串
{
if(sa[i-1]<ls)
{
tmp=min(tmp,height[i]);
pos[sa[i-1]]=max(pos[sa[i-1]],tmp);
}
else tmp=INF;
}
for(int i=1;i<=n;i++) //統計自己與自己重複的子串
{
if(sa[i]<ls&&sa[i-1]<ls)
{
pos[sa[i-1]]=max(pos[sa[i-1]],height[i]);
}
}
ll res=1LL*ls*(ls+1)/2;
for(int i=0;i<ls;i++)
res-=pos[i];
printf("Case %d: %lld\n",++kase,res);
}
return 0;
}