題目http://www.lydsy.com/JudgeOnline/problem.php?id=3879
題目大意就是說給定一個字符串,給定一個序列,要你求序列中兩兩後綴的最長公共前綴(LCP)的和。
本人較傻,沒學過後綴樹那高端玩意,正好前段時間剛學了後綴數組,然後就試着寫了一下,因爲代碼不好看,跑得比較慢。
這題我的思路是先求出height和rank,然後把讀進來的序列(假設當前位置是i,數字是x)a[i] = rank[x] 意思就是序列中第i個數表示字符串中第x個字符的排名(這個實際意義我不會解釋請各位機智的讀者自己yy一下),然後我們排序,由於題目限制,我們需要去重,然後就是求a數組中對於一個a[i]在height中查找a[i - 1] + 1 到a[i]的最小值(因爲兩兩後綴的LCP就是height[rank[l] + 1] ~height[rank[r]]),然後接下來就是求區間最小值的和,然後我們用單調隊列維護。
附上代碼
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MAXX = 500005;
int rank[MAXX], sa[MAXX], x[MAXX], w[MAXX], n, m, i, j, k, M;
char s[MAXX];
int a[MAXX], f[21][MAXX], h, lo[MAXX];
int tail, top[2][MAXX], b[MAXX];
long long ans;
inline int get()
{
char c;
while ((c = getchar()) < 48 || c > 57);
int res = c - 48;
while ((c = getchar()) >= 48 && c <= 57)
res = res * 10 + c - 48;
return res;
}
inline int MIN(const int &x, const int &y)
{
if (x < y) return x;
else return y;
}
int main()
{
n = get(); M = get();
for (i = 0; i <= n; i ++)
lo[i] = (int)log2(i);
scanf("%s", s + 1);
m = 255;
for (i = 1; i <= n; i ++)
w[x[i] = s[i]] ++;
for (i = 2; i <= m; i ++)
w[i] += w[i - 1];
for (i = n; i >= 1; i --)
sa[w[x[i]]--] = i;
for (k = 1; k <= n; k <<= 1)
{
int t = 0;
for (i = n; i >= n - k + 1; i --)
rank[++t] = i;
for (i = 1; i <= n; i ++)
if (sa[i] > k) rank[++t] = sa[i] - k;
for (i = 1; i <= m; i ++)
w[i] = 0;
for (i = 1; i <= n; i ++)
w[x[i]] ++;
for (i = 2; i <= m; i ++)
w[i] += w[i - 1];
for (i = n; i >= 1; i --)
sa[w[x[rank[i]]]--] = rank[i];
m = 0;
for (i = 1; i <= n; i ++)
{
int u = sa[i], v = sa[i - 1];
if (x[u] != x[v] || x[u + k] != x[v + k]) rank[u] = ++m;
else rank[u] = m;
}
for (i = 1; i <= n; i ++)
swap(x[i], rank[i]);
}
h = 0;
for(i = 1; i <= n; i ++)
{
if (h) h --;
int j = sa[rank[i] - 1];
while (s[i + h] == s[j + h]) h ++;
f[0][rank[i]] = h;
}
int mm = (int)log2(n);
for(j = 1; j <= mm; j ++)
for(i = 1; i <= n && (1 << j) + i - 1 <= n; i ++)
f[j][i] = MIN(f[j - 1][i], f[j - 1][i + (1 << (j - 1))]);
while (M--)
{
int len = get();
for(i = 1; i <= len; i ++)
a[i] = rank[get()];
sort(a + 1, a + 1 + len);
len = unique(a + 1, a + 1 + len) - a - 1;
for(i = 1; i <= len; i ++)
{
int l = a[i - 1] + 1, r = a[i];
int jj = lo[r - l + 1];
b[i] = MIN(f[jj][l], f[jj][r - (1 << jj) + 1]);
}
ans = tail = 0;
for(i = 1; i <= len; i ++)
{
while (tail && top[0][tail] > b[i])
ans = ans + (top[1][tail] - top[1][tail - 1]) * (long long)top[0][tail] * (i - top[1][tail]), tail --;
top[0][++tail] = b[i];
top[1][tail] = i;
}
len ++;
for(; tail; tail --)
ans = ans + (top[1][tail] - top[1][tail - 1]) * (long long)top[0][tail] * (len - top[1][tail]);
printf("%lld\n", ans);
}
}