poj 3415 Common Substrings(後綴數組+單調棧)

題意:給2個串,求2個串的後綴之間公共串長度>=k的對數。

做法:單調棧維護3個東西,一個是height,一個是在之上到前一個元素(其實包括這個前一個元素)之間另一個串後綴的個數,還有個位置。用dp[i]代表rank爲第i個串與之上所有另一個串後綴>=k的對數。那麼在加入新串的時候,先出棧height大於當前的,同時統計這之間有多少另一個串後綴的個數x,然後一直出到小於當前height爲止,設這個位置爲j,那麼dp[i] = dp[st[j][3]]+x*(height[i]-k+1)。如果當前這個串是自己這個串的,那麼就ans += dp[i]。然後再把當前的串加入棧中。

其實總結下,單調棧中相鄰的元素中間的數字都可以全部認爲是後面一個元素的height,因爲對於相同的height來說它是最後一個,中間的height肯定只可能比它要大(如果比它小那麼前一個就是這個了),因爲傳遞性的關係,對於下面的串來說,那段都可以被認爲height全等於後面一個元素的。有了這個想法就比較容易理解怎麼做的了。

AC代碼:

//#pragma comment(linker, "/STACK:102400000,102400000")
#include<cstdio>
#include<ctype.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
#include<cstdlib>
#include<stack>
#include<queue>
#include<set>
#include<map>
#include<cmath>
#include<ctime>
#include<string.h>
#include<string>
#include<sstream>
#include<bitset>
using namespace std;
#define ll long long
#define ull unsigned long long
#define eps 1e-8
#define MOD 1000000007
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define PI acos(-1)
template<class T>
inline void scan_d(T &ret)
{
    char c;
    int flag = 0;
    ret=0;
    while(((c=getchar())<'0'||c>'9')&&c!='-');
    if(c == '-')
    {
        flag = 1;
        c = getchar();
    }
    while(c>='0'&&c<='9') ret=ret*10+(c-'0'),c=getchar();
    if(flag) ret = -ret;
}
const int maxn = 200000+10;

int s[maxn],t[maxn],t2[maxn],c[maxn],sa[maxn];
int rank[maxn],height[maxn];
void build_sa(int n, int m)
{
    int i, *x = t, *y = t2;
    for(i = 0; i < m; i++) c[i] = 0;
    for(i = 0; i < n; i++) c[x[i] = s[i]]++;
    for(i = 1; i < m; i++) c[i] += c[i-1];
    for(i = n-1; i >= 0; i--) sa[--c[x[i]]] = i;
    for(int k = 1; k <= n; k <<= 1)
    {
        int p = 0;
        for(i = n-k; i < n; i++) y[p++] = i;
        for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k;
        for(i = 0; i < m; i++) c[i] = 0;
        for(i = 0; i < n; i++) c[x[i]]++;
        for(i = 1; i < m; i++) c[i] += c[i-1];
        for(i = n-1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
        swap(x,y);
        p = 1;x[sa[0]] = 0;
        for(i = 1; i < n; i++)
            x[sa[i]] = y[sa[i-1]]==y[sa[i]] && y[sa[i-1]+k] == y[sa[i]+k] ? p-1 : p++;
        if(p >= n) break;
        m = p;
    }
}

void getHeight(int n)
{
    int i,k = 0;
    for(i = 0; i < n; i++) rank[sa[i]] = i;
    for(i = 0; i < n; i++)
    {
        if(k) k--;
        if(rank[i] == 0) continue;
        int j = sa[rank[i]-1];
        while(s[i+k] == s[j+k]) k++;
        height[rank[i]] = k;
    }
}
char s1[maxn],s2[maxn];
ll dp[maxn];
int st[maxn][3];

int main()
{
#ifdef GLQ
    freopen("input.txt","r",stdin);
//    freopen("o1.txt","w",stdout);
#endif // GLQ
    int k;
    while(~scanf("%d",&k) && k)
    {
        scanf("%s%s",s1,s2);
        int len1 = strlen(s1),len2 = strlen(s2);
        for(int i = 0; i < len1; i++)
            s[i] = s1[i];
        s[len1] = 255;
        for(int i = len1+1; i <= len1+len2; i++)
            s[i] = s2[i-len1-1];
        s[len1+len2+1] = 256;
        int len = len1+len2+2;
        build_sa(len,257);
        getHeight(len);
        int top = 0,cnt;
        ll ans = 0,tmp;
        for(int i = 1; i < len; i++)
        {
            if(height[i] < k)
            {
                top = 0;
                dp[i] = 0;
            }
            else
            {
                cnt = 0;
                if(sa[i-1] > len1)
                    cnt++;
                while(top > 0 && height[i] <= st[top][0])
                {
                    cnt += st[top][1];
                    top--;
                }
                if(top) tmp = dp[st[top][2]];
                else tmp = 0;
                dp[i] = tmp+(ll)cnt*(ll)(height[i]-k+1);
                if(sa[i] < len1)
                {
                    ans += dp[i];
                }
                st[++top][0] = height[i]; st[top][1] = cnt;
                st[top][2] = i;
            }
        }
        top = 0;
        for(int i = 1; i < len; i++)
        {
            if(height[i] < k)
            {
                top = 0;
                dp[i] = 0;
            }
            else
            {
                cnt = 0;
                if(sa[i-1] < len1)
                    cnt++;
                while(top > 0 && height[i] <= st[top][0])
                {
                    cnt += st[top][1];
                    top--;
                }
                if(top) tmp = dp[st[top][2]];
                else tmp = 0;
                dp[i] = tmp+(ll)cnt*(ll)(height[i]-k+1);
                if(sa[i] > len1)
                {
                    ans += dp[i];
                }
                st[++top][0] = height[i]; st[top][1] = cnt;
                st[top][2] = i;
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章