[Comet OJ - Contest #6 E]字符串

Description

给出一个长度为n的字符串S,定义f(S)为S的所有n(n+1)/2n*(n+1)/2个子串,两两求LCP的和
对于每个i,求出f(S[i…n]),答案对998244353取模
n<=200000

Solution

log^2的做法有很多这里就不一一说了
数据结构学傻了.jpg
先考虑两个后缀l和r的所有前缀互相匹配的答案,显然只和后缀长度和LCP有关
设LCP为x,后缀长度为L和R,那么答案为G(l,r,x)=i=1Lj=1Rmin(i,j,x)G(l,r,x)=\sum_{i=1}^{L}\sum_{j=1}^{R}min(i,j,x)
稍微化一下式子,设A=2x33x2+x6,B=xx22,C=xA={2x^3-3x^2+x\over 6},B={x-x^2\over 2},C=x,那么G(l,r,x)=A+B(L+R)+CLRG(l,r,x)=A+B(L+R)+CLR
考虑在后缀树上,每次合并两个后缀集合,那么这两个后缀集合的LCP就是一个常数,也就是A,B,C都是常数
由于要对每个后缀都求答案,只需要把每一对后缀的贡献挂在较小的那个后缀上即可
显然可以发现,对于一个后缀x的答案f(x)f(x)可以写成f(x)=k(nx+1)+bf(x)=k(n-x+1)+b这样的一次函数的形式,于是我们只需要对每个位置维护一次函数的系数即可
对于一个后缀L,如果插入了一个后缀R,那么增量为A+B(L+R)+CLR=L(B+CR)+A+BRA+B(L+R)+CLR=L(B+CR)+A+BR,只和后缀R的长度以及个数有关
考虑用线段树合并解决,每个区间维护区间内后缀的长度和及个数,在合并的时候将用右区间的信息去贡献左区间的一次函数即可做到O(n log n)

Code

#include <set>
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define pb(a) push_back(a)
using namespace std;

typedef long long ll;
typedef set<int> :: iterator it;

const int N=4e5+5,M=1e7+5,Mo=998244353;

int n,pos[N],p[N];
ll an[N];

namespace SAM{
	int son[N][26],len[N],pr[N],tot,lst;
	char s[N];

	int extend(int p,int x) {
		int np=++tot;len[np]=len[p]+1;
		for(;p&&!son[p][x];p=pr[p]) son[p][x]=np;
		if (!p) pr[np]=1;
		else {
			int q=son[p][x];
			if (len[q]==len[p]+1) pr[np]=q;
			else {
				int nq=++tot;
				fo(i,0,25) son[nq][i]=son[q][i];
				pr[nq]=pr[q];len[nq]=len[p]+1;
				pr[q]=pr[np]=nq;
				for(;p&&son[p][x]==q;p=pr[p]) son[p][x]=nq;
			}
		}
		return np;
	}

	void init() {
		scanf("%s",s+1);n=strlen(s+1);
		lst=tot=1;
		fd(i,n,1) pos[i]=lst=extend(lst,s[i]-'a');
	}
}

int cnt[M],ls[M],rs[M],rt[N],tot;
ll sum[M],tgk[M],tgb[M];

void insert(int &v,int l,int r,int x) {
	if (!v) v=++tot;
	cnt[v]++;sum[v]+=n-x+1;
	if (l==r) return;
	int mid=l+r>>1;
	if (x<=mid) insert(ls[v],l,mid,x);
	else insert(rs[v],mid+1,r,x);
}

void down(int v) {
	if (tgk[v]) {
		if (ls[v]) tgk[ls[v]]+=tgk[v];
		if (rs[v]) tgk[rs[v]]+=tgk[v];
		tgk[v]=0;
	}
	if (tgb[v]) {
		if (ls[v]) tgb[ls[v]]+=tgb[v];
		if (rs[v]) tgb[rs[v]]+=tgb[v];
		tgb[v]=0;
	}
}

void upd(int x,int y,ll L) {
	ll A=((L*L*L*2+L*L*3+L)/6-L*L)%Mo;
	ll B=((L*L+L)/2-L*L)%Mo,C=L;
	(tgk[x]+=B*cnt[y]+C*sum[y])%=Mo;
	(tgb[x]+=A*cnt[y]+B*sum[y])%=Mo;
} 

int merge(int x,int y,int L) {
	if (!x||!y) return x+y;
	down(x);down(y);
	upd(ls[x],rs[y],L);
	upd(ls[y],rs[x],L);
	ls[x]=merge(ls[x],ls[y],L);
	rs[x]=merge(rs[x],rs[y],L);
	cnt[x]=cnt[ls[x]]+cnt[rs[x]];
	sum[x]=sum[ls[x]]+sum[rs[x]];
	return x;
}

vector<int> son[N];

void dfs(int x) {
	if (p[x]) insert(rt[x],1,n,p[x]);
	for(int y:son[x]) {
		dfs(y);
		rt[x]=merge(rt[x],rt[y],SAM::len[x]);
	}
}

ll calc(ll l) {
	ll A=(l*l*l*2-l*l*3+l)/6%Mo;
	ll B=(-l*l+l)%Mo,C=l;
	return (A+B*l+C*l*l)%Mo;
}

void get_ans(int v,int l,int r) {
	if (l==r) {
		an[l]=(tgk[v]*(n-l+1)+tgb[v])%Mo;
		an[l]=(an[l]*2+calc(n-l+1))%Mo;
		return;
	}
	int mid=l+r>>1;down(v);
	get_ans(ls[v],l,mid);get_ans(rs[v],mid+1,r);
}

int main() {
	//freopen("e.in","r",stdin);
	//freopen("e.out","w",stdout);
	SAM::init();
	fo(i,2,SAM::tot) son[SAM::pr[i]].pb(i);
	fo(i,1,n) p[pos[i]]=i;
	dfs(1);get_ans(rt[1],1,n);
	fd(i,n,1) (an[i]+=an[i+1])%=Mo;
	fo(i,1,n) printf("%lld ",(an[i]+Mo)%Mo);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章