序列方差[NTT]

也許更好的閱讀體驗

Description\mathcal{Description}

給你一個長度爲nn的數組aa
你會得到 qq 條指令, 分兩種:

  • 1 l r w1\ l\ r\ w 表示把 l,l+1,rl, l + 1,\ldots r 這段區間的每一個數 +w+w.
  • 2 l r2\ l\ r 表示詢問 l,l+1,rl, l + 1, \ldots r 這段區間每個子序列方差之和.

對於每個詢問輸出一行表示答案
答案對998244353998244353取模
n,q1050ai,w<998244353n,q\leq 10^5\\ 0\leq a_i,w<998244353

Solution\mathcal{Solution}

方差的定義
對一個序列a1,a2,,ana_1,a_2,\ldots,a_n,設其平均數爲a=ain\overline a=\dfrac{\sum\limits a_i}{n},方差s2=(aia)2ns^2=\dfrac{\sum\limits \left(a_i-\overline a\right)^2}{n}

前置知識
(a1+a2++an)2=a12+a22++an2+2i!=jaiaj\left(a_1+a_2+\ldots+a_n\right)^2=a_1^2+a_2^2+\ldots+a_n^2+2\sum\limits_{i!=j} a_ia_j

s2=(aia)2n=(ai22aia+a2)n=ai22na2+nan=ai2na2n=ai2ai2+2i!=jaiajnns^2=\dfrac{\sum\limits \left(a_i-\overline a\right)^2}{n}=\dfrac{\sum \left(a_i^2-2a_i\overline a+\overline a^2\right)}{n}=\dfrac{\sum a_i^2-2n\overline a^2+n\overline a}{n}=\dfrac{\sum a_i^2-n\overline a^2}{n}=\dfrac{\sum a_i^2-\dfrac{\sum a_i^2+2\sum\limits_{i!=j}a_ia_j}{n}}{n}
把式子提出來重寫一遍

s2=ai2ai2+2i!=jaiajnns^2=\dfrac{\sum a_i^2-\dfrac{\sum a_i^2+2\sum\limits_{i!=j}a_ia_j}{n}}{n}
於是可以計算上面的ai,aiaja_i,a_ia_j出現次數(貢獻的係數),考慮一個長度爲nn的區間,其子序列長度可以是len=1,2,,nlen=1,2,\ldots,n
考慮aia_i一定被選,剩下n1n-1個數中要選len1len-1個數出來,類似這樣考慮即可
顯然對於每個區間內的數,其係數是一樣的
fnf_n表示長度爲nn的區間裏的ai2a_i^2的係數
則有
fn=i=1n(n1i1)11ii=i=1n(n1i1)i1i2=i=1n(n1)! (i1)(ni)!(i1)! i2=(n1)!i=1n1(ni)!(i2)!i2=(n1)!i+j=nn1(i2)!i21j!\begin{aligned}f_n&=\sum\limits_{i=1}^n\begin{pmatrix}n-1\\i-1\end{pmatrix}\dfrac{1-\dfrac{1}{i}}{i}\\ &=\sum\limits_{i=1}^n\begin{pmatrix}n-1\\i-1\end{pmatrix}\dfrac{i-1}{i^2}\\&= \sum\limits_{i=1}^n\dfrac{\left(n-1\right)!\ \left(i-1\right)}{\left(n-i\right)!\left(i-1\right)!\ i^2}\\&=\left(n-1\right)!\sum\limits_{i=1}^n\dfrac{1}{\left(n-i\right)!\left(i-2\right)!i^2}\\&=\left(n-1\right)!\sum\limits_{i+j=n}^n\dfrac{1}{\left(i-2\right)!i^2}\dfrac{1}{j!}\end{aligned}
注意最後i2i\geq 2
可以發現這是一個卷積的形式,可以nlognnlogn算出來

gng_n表示長度爲nn的區間裏的2aiaj2a_ia_j的係數
則有
gn=i=2n(n2i2)1i2g_n=\sum\limits_{i=2}^n\begin{pmatrix}n-2\\i-2\end{pmatrix}\dfrac{-1}{i^2}
同理可推得
gn=(n2)!i+j=nn1(i2)!i21j!g_n=-\left(n-2\right)!\sum\limits_{i+j=n}^n\dfrac{1}{\left(i-2\right)!i^2}\dfrac{1}{j!}
注意i2i\geq 2
fn,gnf_n,g_n只有前面的係數不同,把後面弄出來即可

另外,對於一個區間,用線段樹維護維護一下ai\sum a_iai2\sum a_i^2
Code\mathcal{Code}

/*******************************
Author:Morning_Glory
LANG:C++
Created Time:2019年10月03日 星期四 09時36分57秒
*******************************/
#include <cstdio>
#include <fstream>
using namespace std;
const int maxn = 1000006;
const int maxt = 1000006;
const int mod = 998244353;
const int gn = 3;
//{{{cin
struct IO{
	template<typename T>
	IO & operator>>(T&res){
		res=0;
		bool flag=false;
		char ch;
		while((ch=getchar())>'9'||ch<'0')	flag|=ch=='-';
		while(ch>='0'&&ch<='9')	res=(res<<1)+(res<<3)+(ch^'0'),ch=getchar();
		if (flag)	res=~res+1;
		return *this;
	}
}cin;
//}}}
inline int add (int &x,int y){	return x=((x+y)%mod+mod)%mod;}
inline int mul (int x,int y){	return 1ll*x*y%mod;}
int n,q,id,ans;
int a[maxn],rev[maxn],inv[maxn],fac[maxn],ifac[maxn];
int f[maxn],g[maxn];
//{{{SegmentTree
struct SegmentTree{
	//0 -> x^2
	//1 -> x
	//{{{defination
	#define cl (k<<1)
	#define cr (k<<1|1)
	#define lm (lt[k]+rt[k])/2
	#define rm (lt[k]+rt[k])/2+1
	#define val0(x) val[x][0]
	#define val1(x) val[x][1]
	#define lazy0(x) lazy[x][0]
	#define lazy1(x) lazy[x][1]
	#define len(x) (rt[x]-lt[x]+1)
	int lt[maxt],rt[maxt],val[maxt][2],lazy[maxt][2];
	//}}}
	//{{{build
	void build (int l,int r,int k=1)
	{
		lt[k]=l,rt[k]=r;
		if (l==r)	return val0(k)=mul(a[l],a[l]),val1(k)=a[l],void();
		build(l,lm,cl);
		build(rm,r,cr);
		val0(k)=(val0(cl)+val0(cr))%mod;
		val1(k)=(val1(cl)+val1(cr))%mod;
	}
	//}}}
	//{{{pushdowna
	void pushdowna (const int &k)
	{
		add(lazy1(cl),lazy1(k));
		add(lazy1(cr),lazy1(k));
		add(val1(cl),mul(len(cl),lazy1(k)));
		add(val1(cr),mul(len(cr),lazy1(k)));
		lazy1(k)=0;
	}
	//}}}
	//{{{pushdowns
	void pushdowns (const int &k)
	{
		add(lazy0(cl),lazy0(k));
		add(lazy0(cr),lazy0(k));
		add(val0(cl),mul(len(cl),mul(lazy0(k),lazy0(k))));
		add(val0(cr),mul(len(cr),mul(lazy0(k),lazy0(k))));
		add(val0(cl),mul(mul(lazy0(k),2),val1(cl)));
		add(val0(cr),mul(mul(lazy0(k),2),val1(cr)));
		lazy0(k)=0;
	}
	//}}}
	//{{{modifys
	void modifys (const int l,const int r,const int v,const int &k=1)
	{
		if (lt[k]>=l&&rt[k]<=r){
			add(val0(k),mul(len(k),mul(v,v)));
			add(val0(k),mul(mul(v,2),val1(k)));
			add(lazy0(k),v);
			return;
		}
		if (lazy0(k))	pushdowns(k);
		if (lazy1(k))	pushdowna(k);
		if (lm>=l)	modifys(l,r,v,cl);
		if (rm<=r)	modifys(l,r,v,cr);
		val0(k)=(val0(cl)+val0(cr))%mod;
		val1(k)=(val1(cl)+val1(cr))%mod;
	}
	//}}}
	//{{{modifya
	void modifya (const int &l,const int &r,const int &v,const int &k=1)
	{
		if (lt[k]>=l&&rt[k]<=r){
			add(val1(k),mul(len(k),v));
			add(lazy1(k),v);
			return;
		}
		if (lazy0(k))	pushdowns(k);
		if (lazy1(k))	pushdowna(k);
		if (lm>=l)	modifya(l,r,v,cl);
		if (rm<=r)	modifya(l,r,v,cr);
		val0(k)=(val0(cl)+val0(cr))%mod;
		val1(k)=(val1(cl)+val1(cr))%mod;
	}
	//}}}
	//{{{query
	int query (const int l,const int r,const bool opt,const int &k=1)
	{
		if (lt[k]>=l&&rt[k]<=r)	return val[k][opt];
		if (lazy0(k))	pushdowns(k);
		if (lazy1(k))	pushdowna(k);
		int res=0;
		if (lm>=l)	add(res,query(l,r,opt,cl));
		if (rm<=r)	add(res,query(l,r,opt,cr));
		return res;
	}
	//}}}
}ST;
//}}}
//{{{ksm
int ksm (int a,int b)
{
	int s=1;
	for (;b;b>>=1,a=1ll*a*a%mod)
		if (b&1)	s=1ll*s*a%mod;
	return s;
}
//}}}
//{{{get_rev
int get_rev (int len)//the maximum power of x is len!!!! not the length
{
	int lim=1,bit=0;
	while (lim<=len)	lim<<=1,++bit;
	for (int i=0;i<lim;++i)	rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	return lim;
}
//}}}
//{{{NTT
void NTT (int *a,int len,int type)
{
	for (int i=0;i<=len-1;++i)
		if (i<rev[i])	swap(a[i],a[rev[i]]);
	for (int i=1;i<len;i<<=1){
		int wn=ksm(gn,(mod-1)/(i*2));
		if (type==-1)	wn=ksm(wn,mod-2);
		for (int j=0;j<len;j+=(i<<1)){
			int w=1;
			for (int k=0;k<i;++k){
				int u=a[j+k],t=1ll*w*a[j+k+i]%mod;
				a[j+k]=1ll*(u+t)%mod;
				a[j+k+i]=1ll*(u-t+mod)%mod;
				w=1ll*w*wn%mod;
			}
		}
	}
	if (type==-1){
		int inv=ksm(len,mod-2);
		for (int i=0;i<=len-1;++i)	a[i]=1ll*a[i]*inv%mod;
	}
}
//}}}
//{{{init
void init ()
{
	fac[0]=ifac[0]=inv[1]=1;
	for (int i=2;i<=n+2;++i)	inv[i]=(-1ll*mod/i*inv[mod%i]%mod+mod)%mod;
	for (int i=1;i<=n+2;++i){
		fac[i]=mul(fac[i-1],i);
		ifac[i]=mul(ifac[i-1],inv[i]);
	}
	for (int i=0;i<=n;++i){
		f[i]=ifac[i];
		g[i]=i<=1?0:mul(mul(inv[i],inv[i]),ifac[i-2]);
	}
	int len=get_rev(n<<1);
	NTT(f,len,1),NTT(g,len,1);
	for (int i=0;i<len;++i)	f[i]=mul(f[i],g[i]);
	NTT(f,len,-1);
	for (int i=0;i<=n;++i){
		g[i]=0;
		if (i>1)	add(g[i],-mul(fac[i-2],f[i])),f[i]=mul(fac[i-1],f[i]);
	}
}
//}}}
int main ()
{
	cin>>n>>q>>id;
	for (int i=1;i<=n;++i)	cin>>a[i];
	init();
	ST.build(1,n);
	while (q--){
		int opt,l,r;
		cin>>opt>>l>>r;
		if (opt==1){
			int x;
			cin>>x;
			ST.modifys(l,r,x);
			ST.modifya(l,r,x);
		}
		else{
			int s1=ST.query(l,r,0),s2=ST.query(l,r,1),len=r-l+1;
			s2=mul(s2,s2);
			add(s2,-s1);
			int ans1=mul(s1,f[len]),ans2=mul(s2,g[len]);
			add(ans1,ans2);
			printf("%d\n",ans1);
		}
	}
	return 0;
}

如有哪裏講得不是很明白或是有錯誤,歡迎指正
如您喜歡的話不妨點個贊收藏一下吧

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章