多項式多點求值
給出一個 次多項式 ,以及一個長度爲 的序列 ,對於每個 ,求出 。
考慮分治,先看區間 ,考慮構造一個 次多項式 ,讓 對 取模,得到 ,其中 就是 除以 的餘數。
將 帶入 ,得到 ,因爲 ,所以有 。
這樣不斷分治下去,每一層的 由上一層的 對 取模得來, 的最高次項會從 到 再到 ……當分治到 ,即最高次項爲常數項時,直接輸出即可。
但是還可以優化一下,即不需要分治到底才輸出,分治到範圍比較小的時候可以暴力求,我的代碼是在 時就直接暴力,比原來的快了 左右。
至於 的話用分治 預處理一下就好了。
代碼如下:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
#define mod 998244353
#define maxn 400010
#define bin(x) (1<<(x))
#define MS(F,x) memset(F,0,(4<<(x)))
int n,m,a[maxn],F[maxn];
int inv[maxn],log_2[maxn];
int ksm(int x,int y){int re=1;for(;(y&1?re=1ll*re*x%mod:0),y;y>>=1,x=1ll*x*x%mod);return re;}
#define INV(x) ksm(x,mod-2)
int *w[30];void prep(int N){
inv[1]=1;for(int i=2;i<=N;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod,log_2[i]=ceil(log2(i));
for(int i=1,wn;i<=log_2[N];i++){
w[i]=new int[bin(i-1)];w[i][0]=1;wn=ksm(3,(mod-1)/(bin(i)));
for(int j=1;j<bin(i-1);j++)w[i][j]=1ll*w[i][j-1]*wn%mod;
}
}
int limit,r[maxn];void work(int lg){for(int i=1;i<bin(lg);i++)r[i]=(r[i>>1]>>1)|((i&1)<<(lg-1));}
void ntt(int *f,int lg,int type=0){
limit=bin(lg);if(type)reverse(f+1,f+limit);
for(int i=1;i<limit;i++)if(i<r[i])swap(f[i],f[r[i]]);
for(int mid=1,Lg=1;mid<limit;mid<<=1,Lg++)for(int j=0;j<limit;j+=(mid<<1))for(int i=0;i<mid;i++)
{int t=1ll*f[j+i+mid]*w[Lg][i]%mod;f[j+i+mid]=(f[j+i]-t+mod)%mod;f[j+i]=(f[j+i]+t)%mod;}
}
void NTT(int *f,int *g,int ln){
int lg=log_2[ln*2-1];work(lg);
ntt(f,lg);ntt(g,lg);for(int i=0;i<bin(lg);i++)f[i]=1ll*f[i]*g[i]%mod;
ntt(f,lg,1);for(int i=0;i<bin(lg);i++)f[i]=i<ln?1ll*f[i]*inv[bin(lg)]%mod:0;
}
int A[maxn],B[maxn],C[maxn],D[maxn],E[maxn],H[maxn];
void getinv(int *f,int *g,int ln)
{
if(ln==1){g[0]=INV(f[0]);return;}getinv(f,g,(ln+1)>>1);int lg=log_2[ln<<1];work(lg);
MS(A,lg);MS(B,lg);memcpy(A,f,ln<<2);memcpy(B,g,ln<<2);
ntt(A,lg);ntt(B,lg);for(int i=0;i<bin(lg);i++)A[i]=1ll*(2-1ll*A[i]*B[i]%mod+mod)%mod*B[i]%mod;
ntt(A,lg,1);for(int i=0;i<ln;i++)g[i]=1ll*A[i]*inv[bin(lg)]%mod;
}
void rev(int *f,int *g,int ln){for(int i=0;i<ln;i++)g[i]=f[ln-1-i];}
void getdiv(int *f,int *g,int *q,int ln1,int ln2){
MS(C,log_2[ln1*2]);MS(D,log_2[ln1*2]);rev(f,C,ln1);rev(g,D,ln2);
for(int i=ln1-ln2+1;i<ln1;i++)C[i]=D[i]=0;
MS(E,log_2[ln1*2]);getinv(D,E,ln1-ln2+1);NTT(C,E,ln1-ln2+1);rev(C,q,ln1-ln2+1);
}
void getmod(int *f,int *g,int *q,int *r,int ln1,int ln2){
MS(A,log_2[ln1*2]);MS(B,log_2[ln1*2]);memcpy(A,g,ln1<<2);memcpy(B,q,ln1<<2);
NTT(A,B,ln1);for(int i=0;i<ln2-1;i++)r[i]=(f[i]-A[i]+mod)%mod;
}
void getmod(int *f,int *g,int *r,int ln1,int ln2){
if(ln1<ln2){memcpy(r,f,ln1<<2);return;}
MS(H,log_2[ln1*2]);getdiv(f,g,H,ln1,ln2);getmod(f,g,H,r,ln1,ln2);
}
int *P[maxn],*R[maxn];inline int kk(int d,int id){return (1<<d)+id;}
void FZ_FFT(int l,int r,int d,int id)
{
P[kk(d,id)]=new int[bin(log_2[r-l+2]+1)];R[kk(d,id)]=new int[bin(log_2[r-l+2]+1)];
if(l==r){P[kk(d,id)][0]=mod-a[l];P[kk(d,id)][1]=1;return;}
int mid=l+r>>1;FZ_FFT(l,mid,d+1,id<<1);FZ_FFT(mid+1,r,d+1,(id<<1)+1);
MS(P[kk(d,id)],log_2[r-l+2]+1);MS(E,log_2[r-l+2]+5);
memcpy(P[kk(d,id)],P[kk(d+1,id<<1)],(mid-l+2)<<2);memcpy(E,P[kk(d+1,(id<<1)+1)],(r-mid+1)<<2);
NTT(P[kk(d,id)],E,r-l+2);
}
void evalu(int l,int r,int *FA,int ln,int d,int id)
{
getmod(FA,P[kk(d,id)],R[kk(d,id)],ln,r-l+2);
if(r-l+1<=16){for(int i=l,tot=0;i<=r;i++,tot=0){int x=1;
for(int j=0;j<r-l+2;j++)tot=(tot+1ll*x*R[kk(d,id)][j])%mod,x=1ll*x*a[i]%mod;printf("%d\n",tot);}return;}
int mid=l+r>>1;evalu(l,mid,R[kk(d,id)],r-l+2,d+1,id<<1);evalu(mid+1,r,R[kk(d,id)],r-l+2,d+1,(id<<1)+1);
}
int main()
{
scanf("%d %d",&n,&m);n++;prep(n<<2);
for(int i=0;i<n;i++)scanf("%d",&F[i]);
for(int i=1;i<=m;i++)scanf("%d",&a[i]);
FZ_FFT(1,m,0,0);evalu(1,m,F,n,0,0);
}