LOJ #150. 挑戰多項式/多項式全家桶

題目

這個題堪稱 簡易模板全家桶

關於各個函數的實現,大多數都是利用牛頓迭代公式(x=x0f(x)f(x)x=x_0-\dfrac {f(x)}{f'(x)})+倍增。

以下爲多項式全家桶證明:

下面的f,b,af,b,a分別爲已知多項式,當前所求多項式和上一個狀態(規模小一半)的多項式.(簡記y=a(x)y=a(x))

sqrt\text{sqrt}

b(x)2=f(x)  >b(x)2f(x)=0b(x)^2=f(x) ~~->b(x)^2-f(x)=0
φ(b(x))=b(x)2f(x)設\varphi(b(x))=b(x)^2-f(x)
φ(b(x))=2b(x)則\varphi'(b(x))=2*b(x)
,此時,我們求函數零點
b(x)=a(x)φ(a(x))φ(a(x))=a(x)a2(x)f(x)2a(x)=a2(x)+f(x)2a(x)b(x)=a(x)-\dfrac{\varphi(a(x))}{\varphi'(a(x))}=a(x)-\dfrac{a^2(x)-f(x)}{2a(x)}=\dfrac{a^2(x)+f(x)}{2a(x)}

y2+f2y\dfrac {y^2+f}{2y}

inv\text{inv}

φ(b(x))=b(x)f(x)1\varphi(b(x))=b(x)f(x)-1

φ(b(x))=f(x)\varphi'(b(x))=f(x)

b(x)=a(x)a(x)f(x)1f(x)b(x)=a(x)-\dfrac{a(x)f(x)-1}{f(x)}

b(x)=a(x)b(x)(a(x)f(x)1)b(x)=a(x)-b(x)(a(x)f(x)-1)

由於a(x)f(x)10(mod  xn/2)a(x)f(x)-1\equiv 0(\mod x^{n/2}).

所以b(x)b(x)的高n/2n/2位乘上a(x)f(x)1a(x)f(x)-1mod  xn\mod x^n意義下爲0.

所以b(x)b(x)a(x)a(x)在當前等價.

b(x)=a(x)a(x)(a(x)f(x)1)=a(x)(2a(x)f(x))b(x)=a(x)-a(x)(a(x)f(x)-1)=a(x)(2-a(x)f(x))

y(2yf)y(2-yf)

ln\ln

lnx=1x\ln 'x=\dfrac 1 x

b(x)=lnf(x)b(x)=\ln f(x)

對兩邊求導:b(x)=inv(f(x))f(x)b'(x)=inv(f(x))f'(x).(鏈式反應)

積分b(x)b'(x)即可得到b(x)b(x).

ff\int \dfrac {f'}f

exp\exp

因爲exp,ln\exp,\ln爲逆運算,所以可得:

b(x)=expf(x)lnb(x)=f(x)b(x)=\exp f(x)\rightarrow \ln b(x)=f(x)

φ(b(x))=lnb(x)f(x)\varphi(b(x))=\ln b(x)-f(x)

φ(b(x))=1b(x)\varphi'(b(x))=\dfrac 1{b(x)}

則可得到:b(x)=a(x)(1lna(x)+f(x))b(x)=a(x)(1-\ln a(x)+f(x)).

y(1lny+f)y(1-\ln y+f)

pow\text{pow}

f(x)k=eln(f(x))kf(x)^k=e^{ln(f(x))*k}

f(x)/g(x),f(x)mod g(x)\text{f(x)/g(x),f(x)mod g(x)}

多項式帶餘除法.

f(x)=q(x)g(x)+r(x)(mod  xn)f(x)=q(x)g(x)+r(x)(\mod x^n).

已知f,gf,gq,rq,r.ffnn次多項式,ggmm次多項式,rr的次數<m<m.

可以發現xnf(x1)x^n f(x^{-1})的係數爲ff係數的翻轉,我們簡記這個多項式爲F(x)F(x).

則有:F(x)=xnq(x)g(x)+xnr(x)=xnmq(x)xmg(x)+xnm+1xm1r(x)=Q(x)G(x)+xnm+1R(x)(mod  xn)F(x)=x^n q(x)g(x)+x^nr(x)=x^{n-m} q(x) x^m g(x)+x^{n-m+1}x^{m-1}r(x)= Q(x)G(x)+x^{n-m+1}R(x)(\mod x^n)

我們可以發現F(x)=Q(x)G(x)(mod  xnm+1)F(x)=Q(x)G(x)(\mod x^{n-m+1}).

這樣求出QQ後即可得到q,rq,r.

代碼:

#include<ctime>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define gc getchar()//(p1==p2&&(p2=(p1=buf)+fread(buf,1,N,stdin),p1==p2)?EOF:*p1++)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int N=1<<22|10,mod=998244353;

char buf[N],*p1=buf,*p2=buf;
template<class o>void qr(o &x) {
	char c=gc; x=0;
	while(!isdigit(c))c=gc;
	while(isdigit(c))x=(x*10+c-'0')%mod,c=gc;
}
template<class o>void qw(o x) {
	if(x/10) qw(x/10);
	putchar(x%10+'0');
}
template<class o>void pr1(o x) {qw(x); putchar(' ');}
template<class o>void pr2(o x)  {qw(x); puts("");}

ll power(ll a,ll b=mod-2,ll p=mod) {
	ll c=1;
	while(b) {
		if(b&1) c=c*a%p;
		b /= 2; a=a*a%p;
	}
	return c;
}

namespace Cipolla {
	ll p=mod,w;
	struct CP {
		ll x,y;
		CP(ll a=1,ll b=0) {x=a; y=b;}
		CP operator *(CP b) const {return CP( ((x*b.x+y*b.y%p*w)%p+p)%p , ((x*b.y+y*b.x)%p+p)%p);}
	} ;
	CP power(CP a,ll b=(p+1)/2) {
		CP c;
		while(b) {
			if(b&1) c=c*a;
			b /= 2; a=a*a;
		}
		return c;
	}
	bool pd(ll x) {return ::power(x,(p-1)/2,p)==p-1;}
	ll solve(ll n) {
		if(n<=1) return n;
		ll a; do {
			a=(rand()<<15|rand())%p;
			w=((a*a-n)%p+p)%p;
		} while(!pd(w));
		ll x=power(CP(a,1)).x,y=p-x;
		return min(x,y);
	}
}

namespace P {
	const int g=3,inv2=(mod+1)/2;
	int R[N],w[N],Inv[N];
	int calc(int x) {if((x&-x)==x) return x; int n=1; while(n<x) n*=2; return n;}//輸入長度 
	void init(int m) {
		int n=calc(m)*4; 
		Inv[1]=1; for(int i=2;i<n;i++) Inv[i]=(ll)Inv[mod%i]*(mod-mod/i)%mod;
		for(int i=1;i<n;i*=2) {//枚舉半區間長度,把對應的單位根填入w數組 
			ll t=power(g,(mod-1)/(2*i)),d=1;
			for(int j=0;j<i;j++) w[i+j]=d,d=d*t%mod;
		}
	}
	int pre(int m) {//輸入總長度 
		int n=calc(m);
		for(int i=1;i<n;i++) R[i]=(R[i>>1]>>1)|(i&1?n>>1:0);
		return n;
	}
	void upd(int &x) {x+=x>>31&mod;}
	void DFT(int *f,int n) {
		static ull p[N];
		for(int i=0;i<n;i++) p[R[i]]=f[i];
		for(int i=1,t;i<n;i*=2) for(int j=0;j<n;j+=2*i)
			for(int k=0;k<i;k++) t=p[j+k+i]*w[i+k]%mod,p[j+k+i]=p[j+k]+mod-t,p[j+k]+=t;
		for(int i=0;i<n;i++) f[i]=p[i]%mod;
	}
	void IDFT(int *f,int n) {
		reverse(f+1,f+n); DFT(f,n); ll inv=power(n);
		for(int i=0;i<n;i++) f[i]=inv*f[i]%mod;
	}
	void copy(int *a,int *b,int n) {memcpy(a,b,sizeof(int[n]));}
	void clear(int *a,int len) {memset(a+len,0,sizeof(int[len]));}
	void clear(int *a,int x,int y) {if(x<y) memset(a+x,0,sizeof(int[y-x]));}
	void dao(int *a,int *b,int n) {
		for(int i=1;i<n;i++) b[i-1]=(ll)a[i]*i%mod;
		b[n-1]=0;
	}
	void ji(int *a,int *b,int n) {
		for(int i=n-1; i;i--) b[i]=(ll)a[i-1]*Inv[i]%mod;
		b[0]=0;
	}
	void mult(int *a,int *b,int n,int m) {
		static int c[N];
		int x=pre(n+m);
		clear(a,n,x); copy(c,b,m); clear(c,m,x);
		DFT(a,x); DFT(c,x);
		for(int i=0;i<x;i++) a[i]=(ll)a[i]*c[i]%mod;
		IDFT(a,x);
	}
	int h[N];
	void getinv(int *a,int *b,int n) {// 
		clear(b,0,2*n); clear(a,n); clear(h,0,2*n); b[0]=power(a[0]);
		for(int p=2;p<=n;p*=2) {
			int x=pre(p*2);
			copy(h,a,p); clear(h,p); DFT(h,x); DFT(b,x);
			for(int i=0;i<x;i++) b[i]=(ll)(2-(ll)b[i]*h[i]%mod+mod)*b[i]%mod;
			IDFT(b,x); clear(b,p);
		}
	}
	void getsqrt(int *a,int *b,int n) {
		static int c[N],f[N];
		clear(c,0,2*n); clear(f,0,2*n); clear(a,n); clear(b,0,2*n);
		b[0]=Cipolla::solve(a[0]); b[1]=0; 
		for(int p=2;p<=n;p*=2) {
			copy(f,a,p); clear(f,p); getinv(b,c,p); mult(f,c,p,p);
			for(int i=0;i<p;i++) b[i]=(ll)(b[i]+f[i])*inv2%mod;
		}
	}
	void getln(int *a,int *b,int n) {// 
		getinv(a,b,n); dao(a,h,n);
		mult(h,b,n,n); ji(h,b,n); clear(h,0,2*n);
	}
	void getexp(int *a,int *b,int n) {
		static int c[N]; clear(b,0,2*n); clear(a,n); b[0]=1;
		for(int p=2;p<=n;p*=2) {
			copy(c,b,p/2); clear(c,p/2); getln(c,b,p);upd(--b[0]);
			for(int i=0;i<p;i++) upd(b[i]=-b[i]+a[i]);
			mult(b,c,p,p); clear(b,p);
		}
	}
	void getdiv(int *a,int *b,int *c,int n,int m) {//c=a(n)/b(m). 
		static int h1[N],h2[N]; 
		int len=n-m+1,x=calc(len);
		copy(h1,a,n); copy(h2,b,m); 
		clear(h1,n,x); clear(h2,m,x);
		reverse(h1,h1+n); reverse(h2,h2+m); 
		getinv(h2,c,x); mult(c,h1,x,x); reverse(c,c+len); clear(c,len);
	}
	void getmod(int *a,int *b,int *c,int n,int m) {
		static int d[N];
		getdiv(a,b,c,n,m);
		copy(d,b,m);mult(c,d,n-m+1,m);
		for(int i=0;i<m;i++) upd(c[i]=a[i]-c[i]);
	}
}

int n,m,t,f[N],g[N],h[N];

int main() {
	qr(n); qr(m); n++; t=P::calc(n); P::init(t);
	for(int i=0;i<n;i++) qr(f[i]);
	P::getsqrt(f,g,t);
	P::getinv(g,h,t);
	P::ji(h,g,t);
	P::getexp(g,h,t);
	for(int i=0;i<t;i++) P::upd(g[i]=f[i]-h[i]);
	P::upd(g[0]+=2); P::upd(g[0]-=f[0]);
	P::getln(g,h,t); P::upd(h[0]+=1-mod);
	P::getln(h,g,t); for(int i=0;i<t;i++) g[i]=(ll)g[i]*m%mod; 
	P::getexp(g,h,t); P::dao(h,f,n); n--;
	for(int i=0;i<n;i++) pr1(f[i]);
	return 0;
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章