多項式快速插值學習小記

  • 今天終於抽空把這個綜(du)合(liu)知識點學了,心力交瘁……

多項式快速插值

  • 給出 nn 個點 (xi,yi)(x_i,y_i) ,要求一個次數爲 n1n-1 的多項式 F(x)F(x) 滿足:F(xi)=yiF(x_i)=y_i

  • 顯然這個多項式是唯一確定的。

  • 根據拉格朗日插值法,我們有:F(x)=i=1nj̸=i(xxj)j̸=i(xixj)yiF(x)=\sum_{i=1}^{n}\frac{\prod_{j\not=i}(x-x_j)}{\prod_{j\not=i}(x_i-x_j)}y_i

  • 這樣我們是可以 O(n2)O(n^2) 求的,考慮優化。

  • 我們先考慮對於每個 ii ,如何快速得到 j̸=i(xixj)\prod_{j\not=i}(x_i-x_j)

  • M(x)=i=1n(xxi)M(x)=\prod_{i=1}^{n}(x-x_i) ,我們即需求:M(x)xxi\frac{M(x)}{x-x_i}

  • 根據洛必達法則,當 xxxix_i 時,分子分母都等於0,一個0/0型,上下求導得:limxxiM(x)xxi=M(x)lim_{x→x_i}\frac{M(x)}{x-x_i}=M'(x)

  • 於是我們先分治NTT求出 M(x)M(x) ,再求導得到 M(x)M'(x) ,之後將 x1nx_{1-n} 代入多點求值即可!

  • 這一部分的複雜度是 O(n log2n)O(n\ log^2n) 的。

  • 求導、多點求值這些前置知識我的博客裏都有講:

    多項式的求逆、取模和多點求值學習小記

    多項式的ln、exp和快速冪學習小記

  • 又設 Vi=yij̸=i(xixj)V_i=\frac{y_i}{\prod_{j\not=i}(x_i-x_j)} ,則此時 ViV_i 已知,我們要求:F(x)=i=1nVij̸=i(xxj)F(x)=\sum_{i=1}^{n}V_i\prod_{j\not=i}(x-x_j)

  • 還是分治NTT,設 L(x)=i=1n/2(xxi)L(x)=\sum_{i=1}^{n/2}(x-x_i)R(x)=i=n/2+1n(xxi)R(x)=\sum_{i=n/2+1}^{n}(x-x_i) ,則有:F(x)=i=1n/2Vij̸=i,1jn/2(xxj)R(x)+i=n/2+1nVij̸=i,n/2+1jn(xxj)L(x)F(x)=\sum_{i=1}^{n/2}V_i\prod_{j\not=i,1\leq j\leq n/2}(x-x_j)R(x)+\sum_{i=n/2+1}^{n}V_i\prod_{j\not=i,n/2+1\leq j\leq n}(x-x_j)L(x)

  • 遞歸即可求得,遞歸底層的 F(x)F(x) 就是 ViV_i

  • 還有就是這裏的 L(x)R(x)L(x)、R(x) 在多點求值中已經算過了,不用再算一遍啦。

  • 這一部分的複雜度也是 O(n log2n)O(n\ log^2n)

  • 故總時間複雜度即爲 O(n log2n)O(n\ log^2n) ,常數很大很大。

  • 模板題:洛谷 P5158 【模板】多項式快速插值

Code

#include<cstdio>
#include<algorithm>
#include<cctype>
using namespace std;
typedef long long LL;
const int N=1e5+5,M=18,G=3,mo=998244353;
int tot;
int xx[N],yy[N],val[N];
int a[N],b[N],c[N],rr[N];//a=b*c+rr
int ra[N],rb[N<<2],irb[N<<2];
int f[N*M<<1],stf[N<<2],enf[N<<2];
int g[N*M<<1],stg[N<<2],eng[N<<2];
int h[N],sth[N],enh[N],f3[N];
int f1[N<<2],f2[N<<2],wn[N<<2],rev[N<<2];
inline int read()
{
	int X=0,w=0; char ch=0;
	while(!isdigit(ch)) w|=ch=='-',ch=getchar();
	while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
	return w?-X:X;
}
void write(int x)
{
	if(x>9) write(x/10);
	putchar(x%10+'0');
}
inline int ksm(int x,int y)
{
	int s=1;
	while(y)
	{
		if(y&1) s=(LL)s*x%mo;
		x=(LL)x*x%mo;
		y>>=1;
	}
	return s;
}
inline void NTT(int *y,int len,int ff)
{
	for(int i=0;i<len;i++)
		if(i<rev[i]) swap(y[i],y[rev[i]]);
	for(int h=2,d=len>>1;h<=len;h<<=1,d>>=1)
		for(int i=0,k=h>>1;i<len;i+=h)
			for(int j=0,cnt=0;j<k;j++,cnt+=d)
			{
				int u=y[i+j],t=(LL)wn[cnt]*y[i+j+k]%mo;
				y[i+j]=u+t>=mo?u+t-mo:u+t;
				y[i+j+k]=u-t<0?u-t+mo:u-t;
			}
	if(ff==-1)
	{
		reverse(y+1,y+len);
		int inv=ksm(len,mo-2);
		for(int i=0;i<len;i++) y[i]=(LL)y[i]*inv%mo;
	}
}
void make(int v,int l,int r)
{
	if(l==r)
	{
		g[stg[v]=++tot]=mo-xx[l];
		g[eng[v]=++tot]=1;
		return;
	}
	int mid=l+r>>1,ls=v<<1,rs=ls|1;
	make(ls,l,mid),make(rs,mid+1,r);
	int na=eng[ls]-stg[ls]+1,nb=eng[rs]-stg[rs]+1;
	int len=1,ll=0;
	while(len<na+nb) len<<=1,ll++;
	for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
	int w0=ksm(G,(mo-1)/len);
	for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
	for(int i=0;i<na;i++) f1[i]=g[stg[ls]+i];
	for(int i=na;i<len;i++) f1[i]=0;
	for(int i=0;i<nb;i++) f2[i]=g[stg[rs]+i];
	for(int i=nb;i<len;i++) f2[i]=0;
	NTT(f1,len,1),NTT(f2,len,1);
	for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
	NTT(f1,len,-1);
	stg[v]=tot+1;
	na+=nb-1;
	for(int i=0;i<na;i++) g[++tot]=f1[i];
	eng[v]=tot;
}
void getinv(int len,int ll)
{
	if(len==1)
	{
		irb[0]=ksm(rb[0],mo-2);
		return;
	}
	getinv(len>>1,ll-1);
	for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
	int w0=ksm(G,(mo-1)/len);
	for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
	for(int i=0;i<len>>1;i++) f1[i]=rb[i];
	for(int i=len>>1;i<len;i++) f1[i]=0;
	NTT(f1,len,1),NTT(irb,len,1);
	for(int i=0;i<len;i++) irb[i]=(2-(LL)f1[i]*irb[i]%mo+mo)*irb[i]%mo;
	NTT(irb,len,-1);
	for(int i=len>>1;i<len;i++) irb[i]=0;
}
void solve(int v,int l,int r,int fa)
{
	int na=enf[fa]-stf[fa],nb=eng[v]-stg[v];
	if(na>=nb)
	{
		int nc=na-nb;
		for(int i=0;i<=na;i++) a[i]=f[stf[fa]+i];
		for(int i=0;i<=nb;i++) b[i]=g[stg[v]+i];
		for(int i=0;i<=nc;i++) ra[i]=a[na-i];
		for(int i=0;i<=nb;i++) rb[i]=b[nb-i];
		for(int i=nc+1;i<=nb;i++) rb[i]=0;
		int len=1,ll=0;
		while(len<=nc*2+1) len<<=1,ll++;
		for(int i=nb+1;i<len;i++) rb[i]=0;
		for(int i=0;i<len;i++) irb[i]=0,f1[i]=0;
		getinv(len,ll);
		for(int i=0;i<=nc;i++) f1[i]=ra[i],f2[i]=irb[i];
		for(int i=nc+1;i<len;i++) f1[i]=f2[i]=0;
		NTT(f1,len,1),NTT(f2,len,1);
		for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
		NTT(f1,len,-1);
		for(int i=0;i<=nc;i++) c[nc-i]=f1[i];
		for(int i=nc+1;i<nb;i++) c[i]=0;
		len=1,ll=0;
		while(len<nb<<1) len<<=1,ll++;
		for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
		int w0=ksm(G,(mo-1)/len);
		for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
		for(int i=0;i<nb;i++) f1[i]=b[i],f2[i]=c[i];
		for(int i=nb;i<len;i++) f1[i]=0,f2[i]=0;
		NTT(f1,len,1),NTT(f2,len,1);
		for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
		NTT(f1,len,-1);
		for(int i=0;i<nb;i++) rr[i]=(a[i]-f1[i]+mo)%mo;
		while(nb>1 && !rr[nb-1]) nb--;
		stf[v]=tot+1;
		for(int i=0;i<nb;i++) f[++tot]=rr[i];
		enf[v]=tot;
	}else
	{
		stf[v]=tot+1;
		for(int i=stf[fa];i<=enf[fa];i++) f[++tot]=f[i];
		enf[v]=tot;
	}
	if(l==r)
	{
		val[l]=f[stf[v]];
		return;
	}
	int mid=l+r>>1;
	solve(v<<1,l,mid,v);
	solve(v<<1|1,mid+1,r,v);
}
void work(int v,int l,int r)
{
	if(l==r) return;
	int mid=l+r>>1,ls=v<<1,rs=ls|1;
	work(ls,l,mid);
	work(rs,mid+1,r);
	int na=enh[l]-sth[l]+1,nb=eng[rs]-stg[rs]+1;
	int len=1,ll=0;
	while(len<na+nb) len<<=1,ll++;
	for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
	int w0=ksm(G,(mo-1)/len);
	for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
	for(int i=0;i<na;i++) f1[i]=h[sth[l]+i];
	for(int i=na;i<len;i++) f1[i]=0;
	for(int i=0;i<nb;i++) f2[i]=g[stg[rs]+i];
	for(int i=nb;i<len;i++) f2[i]=0;
	NTT(f1,len,1),NTT(f2,len,1);
	for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
	NTT(f1,len,-1);
	na+=nb-1;
	for(int i=0;i<na;i++) f3[i]=f1[i];
	for(int i=na;i<len;i++) f3[i]=0;
	
	na=enh[mid+1]-sth[mid+1]+1,nb=eng[ls]-stg[ls]+1;
	for(int i=0;i<na;i++) f1[i]=h[sth[mid+1]+i];
	for(int i=na;i<len;i++) f1[i]=0;
	for(int i=0;i<nb;i++) f2[i]=g[stg[ls]+i];
	for(int i=nb;i<len;i++) f2[i]=0;
	NTT(f1,len,1),NTT(f2,len,1);
	for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
	NTT(f1,len,-1);
	na+=nb-1;
	for(int i=0;i<na;i++) h[sth[l]+i]=f3[i]+f1[i]>=mo?f3[i]+f1[i]-mo:f3[i]+f1[i];
	enh[l]=sth[l]+na-1;
}
int main()
{
	int n=read();
	for(int i=1;i<=n;i++) xx[i]=read(),yy[i]=read();
	make(1,1,n);
	int m=eng[1]-stg[1];
	for(int i=0;i<=m;i++) f1[i]=g[stg[1]+i];
	for(int i=0;i<m;i++) f1[i]=(LL)f1[i+1]*(i+1)%mo;
	f1[m--]=0;
	stf[tot=0]=1;
	for(int i=0;i<=m;i++) f[enf[0]=++tot]=f1[i];
	solve(1,1,n,0);
	for(int i=1;i<=n;i++) val[i]=(LL)yy[i]*ksm(val[i],mo-2)%mo;
	tot=0;
	for(int i=1;i<=n;i++)
	{
		sth[i]=enh[i]=++tot;
		h[tot]=val[i];
	}
	work(1,1,n);
	for(int i=0;i<n;i++) write(h[sth[1]+i]),putchar(' ');
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章