分析:
其實題目就是要求任意兩個後綴T[i]和T[j] (i< j) 的 LCP長度之和。
首先對輸入的字符串反轉後建立SAM。
令 一個節點的Max表示它代表的最長子串。
推論1: 原串中的兩個後綴的LCP長度等於後綴樹上兩個節點的LCA 的 Max.
基於這個推論,只需要枚舉LCA,討論子樹與子樹之間的組合數問題。
雖然理論上要建立後綴樹,但代碼裏並不用真的建一棵樹再DFS。
具體做法如下:
推論2: Max更大的節點一定在後綴樹中深度更大。
根據這個推論,我們可以按照Max從小到大進行排序,得到一個序列dfn[]就是深度從淺到深的節點編號,可以進行DP。
考慮到1<=Max<=n, 代碼中用的是O(n)的基數排序。
注意: DP中累加size的時候,建立的輔助節點nq沒有算在其中。
代碼:
/**************************************************************
Problem: 3238
User: spark
Language: C++
Result: Accepted
Time:3072 ms
Memory:244452 kb
****************************************************************/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cmath>
#include<cstring>
#define LL long long
using namespace std;
const int maxn=1000000+5;
int tot=1,last=1,n,dfn[maxn],cnt[maxn];
char s[maxn];
struct node{
int Next[26];
int Max,size,pre,sons;
}T[maxn<<1];
template <typename T>
inline void _read(T &x){
char ch=getchar(); bool mark=false;
for(;!isdigit(ch);ch=getchar())if(ch=='-')mark=true;
for(x=0;isdigit(ch);ch=getchar())x=x*10+ch-'0';
if(mark)x=-x;
}
void Insert(char x){
int id= x-'a';
int np= ++tot,cur=last;
T[np].Max=T[last].Max+1; T[np].size=T[np].sons=1;
while(cur){
if(!T[cur].Next[id]) T[cur].Next[id]= np;
else {
int v= T[cur].Next[id];
if(T[v].Max==T[cur].Max+1) T[np].pre=v;
else{
int nq= ++tot;
memcpy(T[nq].Next,T[v].Next,sizeof(T[v].Next));
T[nq].Max=T[cur].Max+1;
T[nq].pre=T[v].pre; T[v].pre= nq;
T[np].pre= nq;
for(int i= cur;T[i].Next[id]==v;i=T[i].pre)
T[i].Next[id]=nq;
}
break;
}
cur=T[cur].pre;
}
if(!T[np].pre)T[np].pre=1;
last=np;
}
LL Count(){
int i;
LL ans=0;
for(i=1;i<=tot;i++) cnt[T[i].Max]++;
for(i=1;i<=n;i++)cnt[i]+=cnt[i-1];
for(i=1;i<=tot;i++)dfn[cnt[T[i].Max]--]= i;
// cout<<"dfn: ";for(i=1;i<=tot;i++)cout<<dfn[i]<<" ";cout<<endl;
for(i=tot;i>0;i--)T[T[dfn[i]].pre].size+=T[dfn[i]].size;
for(i=1;i<=tot;i++){
int fa= T[i].pre;
ans+= 1ll*T[fa].sons*T[fa].Max*T[i].size;
T[fa].sons+= T[i].size;
}
return ans;
}
int main(){
int i,j;
scanf("%s",s+1);
n= strlen(s+1);
for(i=n;i>0;i--)Insert(s[i]);
LL ans=0;
for(i=1;i<=n;i++) ans+= (1ll*i*(n-1));
cout<<ans-2*Count()<<endl;
return 0;
}
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#define Marx_is_dead true
#define ll long long
using namespace std;
template <typename T>
inline void _read(T& x){
char t=getchar();bool sign=true;
while(t<'0'||t>'9'){if(t=='-')sign=false;t=getchar();}
for(x=0;t>='0'&&t<='9';t=getchar())x=x*10+t-'0';
if(!sign)x=-x;
}
int n,m;
int last=1,root=1,tot=1;
struct node{
int son[26];
int maxn,size,par,temp;
};
node Auto[2000005];
char s[1000005];
void insert(char ch){
int i,j,k,nq,p,q,t=ch-'a';
int np=++tot;
Auto[np].size=1;
Auto[np].temp=1;
Auto[np].maxn=Auto[last].maxn+1;
for(p=last;!Auto[p].son[t];p=Auto[p].par)Auto[p].son[t]=np;
if(!p)Auto[np].par=root;
else {
q=Auto[p].son[t];
if(Auto[q].maxn!=Auto[p].maxn+1){
nq=++tot;
memcpy(Auto[nq].son,Auto[q].son,sizeof(Auto[q].son));
//Auto[nq]=Auto[q];
Auto[nq].maxn=Auto[p].maxn+1;
Auto[nq].par=Auto[q].par;
Auto[q].par=nq;
Auto[np].par=nq;
for(;Auto[p].son[t]==q;p=Auto[p].par)Auto[p].son[t]=nq;
}
else Auto[np].par=q;
}
if(Auto[np].par==0)Auto[np].par=root;
last=np;
}
int cnt[1000005];
int dfn[1000005];
int main(){
int i,j,len;
scanf("%s",s+1);
len=strlen(s+1);
for(i=len;i;i--){
insert(s[i]);
}
ll ans=1ll*(len+1)*(len-1)*len/2;
for(i=1;i<=tot;i++)cnt[Auto[i].maxn]++;
for(i=1;i<=len;i++)cnt[i]+=cnt[i-1];
for(i=1;i<=tot;i++)dfn[cnt[Auto[i].maxn]--]=i;
for(i=tot;i;i--){
Auto[Auto[dfn[i]].par].size+=Auto[dfn[i]].size;
}
ll dec=0;
for(i=1;i<=tot;i++){
int fa=Auto[i].par;
dec+=1ll*Auto[fa].temp*Auto[i].size*Auto[fa].maxn;
Auto[fa].temp+=Auto[i].size;
}
cout<<ans-2*dec;
}