後綴數組求不同子串個數

題目鏈接:洛谷 P2408 不同子串個數(https://www.luogu.com.cn/problem/P2408)


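後綴數組中的 sa[i] 記錄了排名爲 i 的後綴的起始位置。每個子串都是某個後綴的前綴:長度爲 L 的後綴有 L 個前綴,其中與排名在它前一位的後綴共有的前綴個數,恰好是兩者的最長公共前綴(LCP)長度,也就是 height[i];這些子串已經在排名更靠前的後綴中統計過。因此排名爲 i 的後綴新貢獻 L − height[i] 個不同子串,把所有後綴的貢獻相加就是答案:$\sum_{i=1}^{n}\left(n - sa[i] + 1 - height[i]\right)$(字符串下標從 1 開始)。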

此處用 DC3 算法構造後綴數組,再用 Kasai 算法求 height 數組,模板如下:

#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
//#include <unordered_map>
//#include <unordered_set>
// Generic contest helpers. They appear unused by the code in this file, but
// every macro argument and body is fully parenthesised so that expressions
// such as lowbit(a + b), HalF + 1 or lsn + 1 expand as intended.

// Lowest set bit of x, i.e. x & -x.
#define lowbit(x) ((x) & (-(x)))
// Plain constants instead of macros: a macro named `e` or `pi` would silently
// rewrite every later identifier with that name (locals, parameters, members).
const double pi = 3.141592653589793;
const double e = 2.718281828459045;
// Large int whose four bytes are all 0x3f.
#define INF 0x3f3f3f3f
// Midpoint of the current range [l, r]; expects l and r in scope.
#define HalF (((l) + (r)) >> 1)
// Children of node rt in a 1-indexed heap layout: 2*rt and 2*rt+1.
#define lsn (rt << 1)
#define rsn (rt << 1 | 1)
// Argument lists for recursing into the children; expect rt, l, r, mid in scope.
#define Lson lsn, l, mid
#define Rson rsn, mid + 1, r
// The same argument lists with the query range ql, qr appended.
#define QL Lson, ql, qr
#define QR Rson, ql, qr
// Argument list describing the current node.
#define myself rt, l, r
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;
// Size of the global input character buffer.
const int maxN = 1e5 + 7;
// Suffix array by DC3 ("difference cover modulo 3"), followed by a
// Kasai-style pass that fills rk[] and height[].
//
// How to use:
//   * store the text in s[0 .. n-1]; every value must be in [1, m)
//     (0 is reserved for the sentinel that da() appends at s[n]);
//   * call da(n, m);
//   * afterwards sa[0] == n is the sentinel suffix, sa[1 .. n] are the start
//     positions of the real suffixes in ascending order, rk is the inverse
//     permutation of sa, and height[i] (1 <= i <= n) is the length of the
//     longest common prefix of suffixes sa[i-1] and sa[i].
// Capacity: da() itself writes s[0 .. 3n-1], so keep 3n <= maxN (every array
// is sized at roughly three times the longest supported text).
struct DC3
{
    static const int maxN = 3e5 + 7;
// F(x): slot of position x (x % 3 != 0) in the reduced string. Positions with
// x % 3 == 1 map to [0, tb); positions with x % 3 == 2 are shifted by tb.
// G(x): inverse of F. Both read the local `tb` of dc3() and are #undef'd right
// after dc3() so they do not leak into the rest of the file.
#define F(x) ((x)/3+((x)%3==1?0:tb))
#define G(x) ((x)<tb?(x)*3+1:((x)-tb)*3+2)
    // Results: suffix array, rank (inverse of sa), LCP array, and the input text.
    int sa[maxN], rk[maxN], height[maxN], s[maxN];
    // Scratch buffers. wv holds sort keys inside ssort() and, during the final
    // merge of dc3(), the ranks of the mod-1/mod-2 suffixes read by c12().
    int wa[maxN], wb[maxN], wv[maxN], wss[maxN];
    // Non-zero iff the character triples starting at a and b are identical.
    inline int c0(int *r, int a, int b) { return r[a]==r[b] && r[a+1]==r[b+1] && r[a+2]==r[b+2]; }
    // Used by the merge step of dc3(): non-zero iff the suffix at a (a mod-0
    // position) sorts before the suffix at b, where k == b % 3. Leading
    // characters are compared until both sides reach positions whose ranks
    // are stored in wv; for k == 2 that takes one extra character.
    inline int c12(int k, int *r, int a, int b)
    {
        if(k==2)
            return r[a]<r[b] || (r[a]==r[b] && c12(1,r,a+1,b+1));
        else return r[a]<r[b] || (r[a]==r[b] && wv[a+1] < wv[b+1]);
    }
    // Stable counting sort of the n indices in a[] by key r[a[i]] (keys must
    // lie in [0, m)); the sorted indices are written to b[].
    // Side effect: wv[0 .. n-1] is left holding r[a[i]].
    inline void ssort(int *r, int *a, int *b, int n, int m)
    {
        int i;
        for(i=0; i<n; i++) wv[i] = r[a[i]];
        for(i=0; i<m; i++) wss[i] = 0;
        for(i=0; i<n; i++) wss[wv[i]]++;
        for(i=1; i<m; i++) wss[i] += wss[i-1];
        // Walk backwards so equal keys keep their relative order (stability).
        for(i=n-1; i>=0; i--)
            b[--wss[wv[i]]] = a[i];
    }
    // Builds the suffix array of r[0 .. n-1] (values in [0, m)) into sa[0 .. n-1].
    // r[n] and r[n+1] are overwritten with 0 as padding; the reduced string of
    // the recursion is stored behind the text at r + n, its result at sa + n.
    // da() calls this with a unique trailing 0 sentinel, the usual
    // precondition of this template.
    inline void dc3(int *r, int *sa, int n, int m)
    {
        int i, j, *rn = r + n;
        int *san = sa + n, ta = 0, tb = (n + 1) / 3, tbc = 0, p;
        r[n] = r[n + 1] = 0;
        // Radix-sort the positions i with i % 3 != 0 by their first three
        // characters (least significant character first).
        for(i=0;i<n;i++) if(i % 3 != 0) wa[tbc++] = i;
        ssort(r+2, wa, wb, tbc, m);
        ssort(r+1, wb, wa, tbc, m);
        ssort(r, wa, wb, tbc, m);
        // Name every triple in sorted order; equal triples share a name.
        // p ends up as the number of distinct names.
        for(p=1, rn[F(wb[0])] = 0, i = 1; i < tbc; i++)
            rn[F(wb[i])] = c0(r, wb[i-1], wb[i]) ? p - 1 : p++;
        // Names not yet unique: sort the reduced string recursively.
        // Otherwise the names already are the ranks, so invert them directly.
        if(p < tbc) dc3(rn, san, tbc, p);
        else for(i=0; i<tbc; i++) san[rn[i]] = i;
        // Collect the mod-0 positions in the order of the rank of the suffix
        // right after them (position n-1 is appended when n % 3 == 1), then
        // stable-sort them by their first character.
        for(i=0; i<tbc; i++) if(san[i]<tb) wb[ta++] = san[i] * 3;
        if(n % 3 == 1) wb[ta++] = n - 1;
        ssort(r, wb, wa, ta, m);
        // wb <- mod-1/2 positions in sorted order; wv[pos] <- their rank (read by c12).
        for(i=0; i<tbc; i++) wv[wb[i]=G(san[i])]=i;
        // Merge the sorted mod-0 list (wa) with the sorted mod-1/2 list (wb).
        for(i=0, j=0, p=0; i<ta && j<tbc; p++)
            sa[p]=c12(wb[j]%3, r, wa[i], wb[j]) ? wa[i++] : wb[j++];
        for(; i<ta; p++) sa[p] = wa[i++];
        for(; j<tbc; p++) sa[p] = wb[j++];
    }
#undef F
#undef G
    // Builds sa, rk and height for the text s[0 .. n-1]; values must be in [1, m).
    // s[n] is overwritten with the 0 sentinel and s[n .. 3n-1] are zeroed,
    // because dc3() stores its reduced strings behind the text.
    inline void da(int n,int m)
    {
        for(int i=n; i<n*3; i++) s[i]=0;
        dc3(s, sa, n+1, m);
        int i, j, k=0;
        for(i=0;i<=n;i++) rk[sa[i]]=i;
        // Kasai: visit suffixes in text order. The LCP drops by at most one
        // from suffix i to suffix i+1, so k is reused after decrementing it.
        // Since all text values are >= 1, the sentinel suffix has rank 0, so
        // rk[i] >= 1 for i < n, and the unique 0 stops the while loop.
        for(i=0;i<n;i++){
            if(k)k--;
            j=sa[rk[i]-1];
            while(s[i+k]==s[j+k])k++;
            height[rk[i]]=k;
        }
    }
} sa;
// Raw input string, stored 1-indexed at s[1 .. len].
char s[maxN];
// Reads the declared length and the string, builds the suffix array of
// [256, s_1 .. s_len] and prints the number of distinct non-empty substrings:
//     sum over ranks i = 1..len of (len - sa[i] + 1 - height[i]).
signed main()
{
    int len;
    if(scanf("%d", &len) != 1) return 0;
    // Width 100001 keeps s + 1 (plus its '\0') inside s[], and keeps
    // 3 * (len + 1) within the DC3 work arrays used by sa.da(len + 1, ...).
    if(scanf("%100001s", s + 1) != 1) return 0;
    // Never trust the declared length beyond what was actually read: a '\0'
    // copied into sa.s would collide with the 0 sentinel DC3 relies on.
    const int readLen = static_cast<int>(strlen(s + 1));
    if(len < 0 || len > readLen) len = readLen;
    // Position 0 gets 256 (larger than any byte), so its suffix sorts last and
    // ranks 1..len hold exactly the suffixes starting at positions 1..len.
    sa.s[0] = 256;
    for(int i=1; i<=len; i++)
    {
        // Go through unsigned char: plain char may be signed, and a negative
        // value would index DC3's counting-sort buckets out of range.
        sa.s[i] = static_cast<unsigned char>(s[i]);
    }
    // No store to sa.s[len + 1] is needed: da() writes the 0 sentinel there.
    sa.da(len + 1, 300);
    ll ans = 0;
    for(int i=1; i<=len; i++)
    {
        const int id = sa.sa[i];
        // The suffix at id has len - id + 1 prefixes; height[i] of them were
        // already counted via the suffix ranked just before it.
        ans += (len - id + 1 - sa.height[i]);
    }
    printf("%lld\n", ans);
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章