題目鏈接
半前綴定義
從這個前綴中刪去一個子串(或者也可以不刪去),使得半前綴爲。當然,本題中,半前綴可以是空串。也可以是個前綴。
我們定義半前綴爲。
求解半前綴
首先,假設和是兩個相同的半前綴,並且,於是有這部分一定是等於因爲前綴是連續的,說明可以用j後面的部分來進行替代。
那麼,什麼時候不存在這樣的呢?通過等於分析得:在[i + 1, n]找到一個子串,該子串的首字母不是s[i + 1]就可以滿足不存在等於的結果。
所以,問題就變成了當我們枚舉i,然後找[i + 1, n]不以s[i + 1]開頭的不同的子串的個數——這裏特別強調“不同的子串”。
求解[i + 1, n]不同的子串的個數
問題的最後一個關鍵就在於求解[i + 1, n]的不同的子串的個數了,並且最好是O(N)的辦法來進行求解的。
看到這個問題,我們可以將其變化爲求解某後綴和在字符串位置中在它後面的後綴有多少個最長的公共前綴。那麼,只用在最後面的時候算一遍就可以了,起到了去重的作用。
算某個排名爲rk的後綴和sa在它之後的後綴的最長公共前綴的長度,我們可以用單調隊列來解決。
我們維護一個sa單調遞減的單調隊列,當一個原本在隊列內的元素要出隊的時候,就說明要進來一個sa比它大的,且rk排名在它之後的元素,它和現在的元素,以及在它之後進隊、之前出隊的元素構成了一個LCP=min(height)的形式。所以可以求得它和排名在它之後的且sa在它之後的元素對它的最長公共前綴的長度。
但是我們別忘了求排名在它之前的且sa在它之後的元素對它的最長公共前綴的長度,這時候,反着跑一遍單調隊列就可以了。
Code
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
//#include <unordered_map>
//#include <unordered_set>
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f
#define HalF (l + r)>>1
#define lsn rt<<1
#define rsn rt<<1|1
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;
const int maxN = 1e6 + 7;
struct DC3
{
static const int maxN = 3e6 + 7;
#define F(x) ((x)/3+((x)%3==1?0:tb))
#define G(x) ((x)<tb?(x)*3+1:((x)-tb)*3+2)
int sa[maxN], rk[maxN], height[maxN], s[maxN];
int wa[maxN], wb[maxN], wv[maxN], wss[maxN];
inline int c0(int *r, int a, int b) { return r[a]==r[b] && r[a+1]==r[b+1] && r[a+2]==r[b+2]; }
inline int c12(int k, int *r, int a, int b)
{
if(k==2)
return r[a]<r[b] || (r[a]==r[b] && c12(1,r,a+1,b+1));
else return r[a]<r[b] || (r[a]==r[b] && wv[a+1] < wv[b+1]);
}
inline void ssort(int *r, int *a, int *b, int n, int m)
{
int i;
for(i=0; i<n; i++) wv[i] = r[a[i]];
for(i=0; i<m; i++) wss[i] = 0;
for(i=0; i<n; i++) wss[wv[i]]++;
for(i=1; i<m; i++) wss[i] += wss[i-1];
for(i=n-1; i>=0; i--)
b[--wss[wv[i]]] = a[i];
}
inline void dc3(int *r, int *sa, int n, int m)
{
int i, j, *rn = r + n;
int *san = sa + n, ta = 0, tb = (n + 1) / 3, tbc = 0, p;
r[n] = r[n + 1] = 0;
for(i=0;i<n;i++) if(i % 3 != 0) wa[tbc++] = i;
ssort(r+2, wa, wb, tbc, m);
ssort(r+1, wb, wa, tbc, m);
ssort(r, wa, wb, tbc, m);
for(p=1, rn[F(wb[0])] = 0, i = 1; i < tbc; i++)
rn[F(wb[i])] = c0(r, wb[i-1], wb[i]) ? p - 1 : p++;
if(p < tbc) dc3(rn, san, tbc, p);
else for(i=0; i<tbc; i++) san[rn[i]] = i;
for(i=0; i<tbc; i++) if(san[i]<tb) wb[ta++] = san[i] * 3;
if(n % 3 == 1) wb[ta++] = n - 1;
ssort(r, wb, wa, ta, m);
for(i=0; i<tbc; i++) wv[wb[i]=G(san[i])]=i;
for(i=0, j=0, p=0; i<ta && j<tbc; p++)
sa[p]=c12(wb[j]%3, r, wa[i], wb[j]) ? wa[i++] : wb[j++];
for(; i<ta; p++) sa[p] = wa[i++];
for(; j<tbc; p++) sa[p] = wb[j++];
}
inline void da(int n,int m)
{
for(int i=n; i<n*3; i++) s[i]=0;
dc3(s, sa, n+1, m);
int i, j, k=0;
for(i=0;i<=n;i++) rk[sa[i]]=i;
for(i=0;i<n;i++)
{
if(k)k--;
j=sa[rk[i]-1];
while(s[i+k]==s[j+k])k++;
height[rk[i]]=k;
}
}
} sa;
char s[maxN];
int Stap[maxN], Stop, max_lcp[maxN], lcp[maxN];
ll suff_sum[32] = {0};
signed main()
{
scanf("%s", s + 1);
int len = (int)strlen(s + 1);
sa.s[0] = 28;
for(int i=1; i<=len; i++)
{
sa.s[i] = s[i] - 'a' + 1;
}
sa.s[len + 1] = 28;
sa.da(len + 2, 30);
// for(int i=1; i<=len; i++) printf("%d%c", sa.sa[i], i == len ? '\n' : ' ');
// for(int i=1; i<=len; i++) printf("%d%c", sa.height[i], i == len ? '\n' : ' ');
Stop = 0;
for(int i=1, now; i<=len; i++)
{
now = sa.height[i];
while(Stop > 0 && Stap[Stop] < sa.sa[i])
{
max_lcp[Stap[Stop]] = max(max_lcp[Stap[Stop]], now);
now = min(lcp[Stap[Stop]], now);
Stop--;
}
Stap[++Stop] = sa.sa[i];
lcp[Stap[Stop]] = now;
}
Stop = 0;
for(int i=len, now; i; i--)
{
now = sa.height[i + 1];
while(Stop > 0 && Stap[Stop] < sa.sa[i])
{
max_lcp[Stap[Stop]] = max(max_lcp[Stap[Stop]], now);
now = min(lcp[Stap[Stop]], now);
Stop--;
}
Stap[++Stop] = sa.sa[i];
lcp[Stap[Stop]] = now;
}
ll ans = 0, all = 0;
for(int i=len; i>=0; i--)
{
ans = ans + 1LL + all - suff_sum[sa.s[i + 1]];
all = all + len - i + 1 - max_lcp[i];
suff_sum[sa.s[i]] += len - i + 1 - max_lcp[i];
}
printf("%lld\n", ans);
return 0;
}