洛谷P1117 [NOI2016]優秀的拆分(巧妙的計數方法)

文章目錄

題目

[NOI2016]優秀的拆分

分析

統計以SiS_i開頭的形如AA\text{AA}的子串的數量,存入L[i]L[i];統計以SiS_i結尾的形如AA\text{AA}的子串,存入R[i]R[i]。於是把可以把它們拼起來,答案就是i=2n(L[i]×R[i1])\sum \limits_{i = 2}^{n} (L[i] \times R[i - 1])
LLRR數組的處理,暴力枚舉+哈希判斷相等是O(n2)O(n^2)的,考慮優化這個東西。

我們枚舉A\text{A}的長度ll,那麼一個A\text{A}SS中會經過且僅經過一個SklS_{k \cdot l},如圖所示,藍點是SklS_{k \cdot l},可見任何一個長度爲ll的子串必然經過一個藍點。
圖一
那麼我們把這個串看成左右兩端,即以SklS_{k \cdot l}開始的後綴(下圖中橙色示意的範圍)和以SklS_{k \cdot l}開始的前綴(下圖中綠色示意的範圍),這兩個前後綴在SklS_{k \cdot l}處重合。
圖二
不妨假設這個串是某個AA\text{AA}的子串的前一個A\text{A},那麼它後面緊接着一個跟它一模一樣的:
圖三
即橙色(兩個後綴)和綠色(兩個前綴)分別相等。
發現了,我們只需要找到以SklS_{k \cdot l}S(k+1)lS_{(k + 1) \cdot l}結尾的最長公共後綴(LCS),和以SklS_{k \cdot l}S(k+1)lS_{(k + 1) \cdot l}開頭的最長公共前綴(LCP),這兩個二分+哈希O(nlogn)O(n \log n)即可找到。

找到了過後,看下圖(下圖的l=7l = 7,且只是截取了SS中的一部分),假設橙色標記的是LCS,綠色標記的是LCP,那麼紅色標記三對子串都是形如AA\text{AA}的:
圖四
這個時候我們就左邊的三個端點(灰色)的R[i]R[i]全部加一,右邊的三個端點(灰色)的L[i]L[i]全部加一即可,只有區間加法,差分一下即可 (當然線段樹也可以)

給不明白差分的小夥伴;
L[i]=L[i]L[i1]L'[i] = L[i] - L[i - 1],那麼我們對L[i]L'[i]進行操作,最後可以通過L[i]=j=1iL[j]L[i] = \sum \limits_{j = 1}^{i} L'[j],來還原L[i]L[i]
發現L[i]L[i]其實是L[i]L'[i]的前綴和數組,那麼L[i]L[i]的區間加法([l,r][l, r]上加dd),在L[i]L'[i]上只用改兩個點:L[l]+=dL'[l] += dL[r+1]=dL'[r + 1] -= d,這樣一來,想想算前綴和的過程,[l,r][l, r]這一段全部都多了dd

總時間複雜度O((n1+n2++nn)logn)=O(nlog2n)O\left(\left(\dfrac{n}{1}+\dfrac{n}{2}+\cdots+\dfrac{n}{n}\right)\log n\right)=O(n\log^2 n)(用SAM/SA可以少個log?)

代碼

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>

int Read() {
    int x = 0; bool f = false; char c = getchar();
    while (c < '0' || c > '9')
        f |= c =='-', c = getchar();
    while (c >= '0' && c <= '9')
        x = x * 10 + (c ^ 48), c = getchar();
    return f ? -x : x;
}

typedef long long LL;

const int MAXN = 30000;
const int PRIME = 233;
const int MOD = 1000000009;

int N;
char S[MAXN + 5];
int L[MAXN + 5], R[MAXN + 5];

int Hash[MAXN + 5], Pow[MAXN + 5];

int Key(int lft, int rgt) {
    return (Hash[rgt] - (LL)Hash[lft - 1] * Pow[rgt - lft + 1] % MOD + MOD) % MOD;
}

int GetLCS(int i, int j) {
    int lft = 0, rgt = std::min(j - i, i) + 1; // 注意上界不要超了,否則會訪問到不該訪問的地方
    while (lft + 1 < rgt) {
        int mid = (lft + rgt) >> 1;
        if (Key(i - mid + 1, i) == Key(j - mid + 1, j))
            lft = mid;
        else
            rgt = mid;
    }
    return lft;
}

int GetLCP(int i, int j) {
    int lft = 0, rgt = std::min(j - i, N - j + 1) + 1; // 這裏也是
    while (lft + 1 < rgt) {
        int mid = (lft + rgt) >> 1;
        if (Key(i, i + mid - 1) == Key(j, j + mid - 1))
            lft = mid;
        else
            rgt = mid;
    }
    return lft;
}

int main() {
    Pow[0] = 1;
    int T = Read();
    while (T--) {
        scanf("%s", S + 1);
        N = strlen(S + 1);
        for (int i = 1; i <= N; i++) {
            L[i] = R[i] = 0;
            Pow[i] = (LL)Pow[i - 1] * PRIME % MOD;
            Hash[i] = ((LL)Hash[i - 1] * PRIME + (S[i] - 'a')) % MOD;
        }
        for (int len = 1; 2 * len <= N; len++) {
            for (int i = 1; i + len <= N; i += len) {
                int lcs = GetLCS(i, i + len), lcp = GetLCP(i, i + len);
                if (lcs + lcp - 1 >= len) {
                    L[i - lcs + 1]++, L[i + lcp - len + 1]--;
                    R[i - lcs + 2 * len]++, R[i + lcp + len]--; // 這四個點自己參照圖算一下就能找到
                }
            }
        }
        for (int i = 1; i <= N; i++)
            L[i] += L[i - 1], R[i] += R[i - 1]; // 由差分數組還原
        long long Ans = 0;
        for (int i = 2; i <= N; i++)
            Ans += (long long)L[i] * R[i - 1];
        printf("%lld\n", Ans);
    }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章