###題目大意
給你一個長度爲n的字符串S,求最長的一個字符串序列a[1…k]滿足序列中的每一個字符串都是S的子串,且對於任意的都有a[i−1]在a[i]中至少出現兩次。兩次出現允許重疊。
問最大滿足條件的k是多少。
n<=200000
部分分n<=4000
###解題思路
部分分可以很顯然地設f[i,j]表示a[1]=s[i…j]時的最大k值。我們轉移的時候,可以只轉移s[i…j]的最長border,然後再加多兩個轉移,即轉移到f[i-1,j]和f[i,j+1]。
觀察性質,其實,後面兩個轉移是多餘的。
也就是說,我們只要找到一個子串,他border調用次數最大。
考慮上SAM,在parent樹上做,設F[x]表示節點x所代表的最長子串的k值。
顯然x只能由parent鏈上的祖先轉移過來,由於不能確定是哪個子串,我們不管border,直接用題意的出現次數即可。具體的,我們選擇x的right集裏面一個位置p,假如一個祖先y能轉移到x,那麼y的right集裏面必須要有元素屬於[p-mx_len[x]+mn_len[y],p-1],mnlen代表一個節點所代表的最短子串。轉移過來的話,f值+1
我們可以發現,只需要從祖先中找f值最大的轉移即可。我們設pos[x]來維護x到祖先,最大的f節點是誰。
right集判斷,用可持久化線段樹合併即可,常見的維護right集套路。
注意一定要每次合併都建新點。
###代碼
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<map>
using namespace std;
typedef long long ll;
typedef double db;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a<b)?a:b)
const int N=4e5+5,mo=998244353,xs=1331;
const int M=N*20;
int ts,trie[N][26],f[N],fail[N],mx[N],n,lst[N],i,x,ans,arb[N],ls[M],rs[M],tt,rt[N],d[N],pos[N],le,v;
char s[N];
void change(int &x,int l,int r,int p)
{
if (!x) x=++tt;
if (l==r) return ;
int m=l+r>>1;
if (p<=m) change(ls[x],l,m,p);
else change(rs[x],m+1,r,p);
}
int merge(int x,int ax)
{
if (!x||!ax) return x+ax;
int nx=++tt;
ls[nx]=merge(ls[x],ls[ax]);
rs[nx]=merge(rs[x],rs[ax]);
return nx;
}
int get(int x,int l,int r,int i,int j)
{
if (!x) return 0;
if (l==i&&r==j)
return 1;
int m=l+r>>1;
if (j<=m) return get(ls[x],l,m,i,j);
else if (m<i) return get(rs[x],m+1,r,i,j);
else return get(ls[x],l,m,i,m)+get(rs[x],m+1,r,m+1,j);
}
bool cmp(int x,int y)
{
return mx[x]<mx[y];
}
int ins(int lst,int x)
{
int p,q,np,nq,j;
np=++ts;
mx[np]=mx[lst]+1;
p=lst;
while (p&&!trie[p][x]) trie[p][x]=np,p=fail[p];
if (!p) fail[np]=1;
else
{
q=trie[p][x];
if (mx[q]==mx[p]+1) fail[np]=q;
else
{
nq=++ts;
mx[nq]=mx[p]+1;
while (p&&trie[p][x]==q) trie[p][x]=nq,p=fail[p];
fail[nq]=fail[q];
fail[q]=fail[np]=nq;
fo(j,0,25) trie[nq][j]=trie[q][j];
}
}
change(rt[np],1,n,i);
arb[np]=i;
return np;
}
int main()
{
freopen("t2.in","r",stdin);
// freopen("cat.out","w",stdout);
scanf("%s",s+1);
n=strlen(s+1);
lst[0]=++ts;
rt[1]=++tt;
fo(i,1,n)
lst[i]=ins(lst[i-1],s[i]-'a');
fo(i,1,ts) d[i]=i;
sort(d+1,d+1+ts,cmp);
fd(i,ts,2)
{
x=d[i];
rt[fail[x]]=merge(rt[fail[x]],rt[x]);
cmax(arb[fail[x]],arb[x]);
}
pos[1]=1;
mx[0]=-1;
fo(i,2,ts)
{
x=d[i];
pos[x]=pos[fail[x]];
le=arb[x]-mx[x]+mx[fail[pos[x]]]+1;
v=get(rt[pos[x]],1,n,le,arb[x]-1);
if (v)
{
f[x]=f[pos[x]]+1;
pos[x]=x;
cmax(ans,f[x]);
}
}
printf("%d\n",ans);
}