http://acm.hdu.edu.cn/showproblem.php?pid=4436
給出n個串,問這些串中所有不同的子串可組成的數字之和模2012的結果是多少?
將n個串接到一起,中間用分隔符隔開,然後求長串的後綴數組,解height數組
做過後綴數組的都知道,我們想知道他們中的不同的子串是可以的,但想想這樣其實是個暴力的方法,嘗試了下果斷是TLE的。
但這可以給我們之後的解題提供思路,先來看看怎麼找到所有不同子串:
height數組表示的意思是排在第i位的後綴,和排在第i-1位的後綴的最長公共前綴的長度,
比如:
height[i]=3
第i-1位:[123]45$....
第i位: [123]567$....
然後我們在統計第i位開始的不同子串時,我們是不需要計算1,12,123的,新的子串只有1235,12356,123567。
所以第i位開始可構成的新的子串就是從原串的sa[i]+height[i]開始,一直到遇到$爲止的那些串,我們把他們加到一起就可以得到答案了。
但是,簡單想想,這些不同的子串太多了,這樣必然是TLE。我們要加快計算的過程!
我們計len[i]表示從第i個字符往後,到遇到第一個$時的長度;
我們計sum[i]表示,以第i個字符開始的,直到遇到第一個$的所有子串的和;如123$....,sum[0]=1+12+123=134
這時在計算第i位的時候,sum[i]=1+12+123+1235+12356+123567,sum[i+height[i]]=5+56+567
我們希望只把1235,12356,123567加進去,實際上就是(123*10+5)+(123*100+56)+(123*1000+567)=5+56+567+123*1110
我們可以記錄當前公共部分的值num[i]=123,然後將其乘上1110(len[sa[i]+height[i]]個1)再加上sum[i+height[i]]就行了!
後綴數組解法代碼:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn = 200000;
const int mod = 2012;
int wa[maxn],wb[maxn],wv[maxn],ws[maxn];
int cmp(int *r,int a,int b,int l)
{return r[a]==r[b]&&r[a+l]==r[b+l];}
void da(int *r,int *sa,int n,int m)
{
int i,j,p,*x=wa,*y=wb,*t;
for(i=0;i<m;i++) ws[i]=0;
for(i=0;i<n;i++) ws[x[i]=r[i]]++;
for(i=1;i<m;i++) ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--) sa[--ws[x[i]]]=i;
for(j=1,p=1;p<n;j*=2,m=p)
{
for(p=0,i=n-j;i<n;i++) y[p++]=i;
for(i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
for(i=0;i<n;i++) wv[i]=x[y[i]];
for(i=0;i<m;i++) ws[i]=0;
for(i=0;i<n;i++) ws[wv[i]]++;
for(i=1;i<m;i++) ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--) sa[--ws[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
return;
}
int n;
int height[maxn],rank[maxn];
char str[maxn],s[maxn];
int r[maxn];
int sa[maxn];
void get_height(){
int k=0,i,j;
for(i=1;i<=n;i++)rank[sa[i]]=i;
for(i=0;i<n;height[rank[i++]]=k)
for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
for(int i=0;i<n;i++)height[i]=height[i+1];
for(int i=0;i<n;i++)sa[i]=sa[i+1];
}
int num[maxn];
int sum[maxn],len[maxn];
int pow_mod(int a,int b){
int ans=1;
while(b){
if(b&1)ans=ans*a%mod;
a=a*a%mod;
b=b/2;
}
return ans;
}
int pow[maxn];
int get_mul(int a,int l){
return a*pow[l]%mod;
}
int main(){
int b=1;
pow[0]=0;
for(int i=1;i<maxn;i++){
b=b*10%mod;
pow[i]=(pow[i-1]+b)%mod;
}
int m;
while(~scanf("%d",&m)){
int lable=11;
n=0;
for(int i=0;i<m;i++){
scanf("%s",str);
for(int j=0;str[j];j++){
r[n++]=str[j]-'0'+1;
}
r[n++]=lable++;
}
r[n]=0;
da(r,sa,n+1,lable);
get_height();
int ans=0;
int top=0;
sum[n]=0;
for(int i=n-1;i>=0;i--){
if(r[i]>10||r[i]==0){sum[i]=0;len[i]=0;}
else{
len[i]=len[i+1]+1;
sum[i]=(sum[i+1]+get_mul(r[i]-1,len[i+1])+r[i]-1)%mod;
}
}
if(r[sa[0]]<=10&&r[sa[0]]!=1)ans=sum[sa[0]]%mod;
num[0]=0;
top=0;
for(int i=sa[0];r[i]<=10&&r[i];i++){
num[top+1]=(num[top]*10+r[i]-1)%mod;
top++;
}
for(int i=1;i<n;i++){
if(r[sa[i]]<=1||r[sa[i]]>10){top=0;continue;}
for(int j=top+1;j<=height[i];j++){
num[j]=(num[j-1]*10+r[sa[i]+j-1]-1)%mod;
}
top=height[i];
ans=(ans+sum[sa[i]+height[i]]+get_mul(num[top],len[sa[i]+height[i]]))%mod;
}
printf("%d\n",ans);
}
}
後綴自動機:
http://blog.csdn.net/acm_cxlove/article/details/8234200
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int maxn = 220000;
const int mod = 2012;
struct Node{
Node* f,*ch[11];
int len,cnt,sum;
}node[maxn],*init,*tail,*que[maxn];
int top;
void add(int c,int len){
Node *p=tail,*np=&node[top++];
np->len=len;
for(;p&&!p->ch[c];p=p->f)p->ch[c]=np;
tail=np;
if(!p)np->f=init;
else if(p->ch[c]->len==p->len+1)np->f=p->ch[c];
else{
Node *q=p->ch[c],*r=&node[top++];
*r=*q;
r->len=p->len+1;
q->f=np->f=r;
for(;p&&p->ch[c]==q;p=p->f)p->ch[c]=r;
}
}
char str[maxn/2];
int c[maxn],len;
int main(){
int n;
while(~scanf("%d",&n)){
memset(c,0,sizeof(c));
memset(node,0,sizeof(node));
top=1;
init=tail=&node[top++];
len=0;
for(int i=0;i<n;i++){
scanf("%s",str);
for(int j=0;str[j];j++){
add(str[j]-'0',++len);
}
add(10,++len);
}
for(int i=1;i<top;i++)c[node[i].len]++;
for(int i=1;i<=len;i++)c[i]+=c[i-1];
for(int i=1;i<top;i++)que[c[node[i].len]--]=&node[i];
init->cnt=1;
for(int i=1;i<top;i++){
Node* p=que[i];
for(int j=0;j<10;j++){
if(i==1&&j==0)continue;
if(p->ch[j]){
Node* q=p->ch[j];
q->cnt=(q->cnt+p->cnt)%mod;
q->sum=(q->sum+p->sum*10+j*p->cnt)%mod;
}
}
}
int ans=0;
for(int i=1;i<top;i++){
ans=(ans+node[i].sum)%mod;
}
printf("%d\n",ans);
}
}