多項式開根
跟 套路是一樣的,套一個牛頓迭代就出來了。
給出多項式 ,求出一個多項式 滿足 。
設 ,那麼就是要求出 的零點。
考慮倍增,當 時,這題保證了 ,故 。假設此時已經求出了 ,滿足 ,根據牛頓迭代,有
於是就做完了,代碼如下:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <algorithm>
using namespace std;
#define maxn 600010
#define mod 998244353
#define bin(x) (1<<(x))
int n;
int inv[maxn];
int ksm(int x,int y){int re=1;for(;(y&1?re=1ll*re*x%mod:0),y;y>>=1,x=1ll*x*x%mod);return re;}
#define INV(x) ksm(x,mod-2)
struct NTT{
vector<int>w[30];NTT(){
inv[1]=1;for(int i=2;i<=maxn-10;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
for(int i=1,wn;i<=21;i++){
w[i].resize(bin(i));w[i][0]=1;wn=ksm(3,(mod-1)/bin(i));
for(int j=1;j<bin(i-1);j++)w[i][j]=1ll*w[i][j-1]*wn%mod;
}
}
int r[maxn],limit;void dft(int *f,int lg,int type=0)
{
limit=bin(lg);if(type)reverse(f+1,f+limit);
for(int i=1;i<limit;i++){r[i]=(r[i>>1]>>1)|((i&1)<<(lg-1));if(i<r[i])swap(f[i],f[r[i]]);}
for(int mid=1,Lg=1;mid<limit;mid<<=1,Lg++)for(int j=0;j<limit;j+=(mid<<1))for(int i=0;i<mid;i++)
{int t=1ll*w[Lg][i]*f[j+i+mid]%mod;f[j+i+mid]=(f[j+i]-t+mod)%mod;f[j+i]=(f[j+i]+t)%mod;}
}
}ntt;
int A[maxn],B[maxn],C[maxn],M;
struct POLY{
vector<int> a;int len;void rs(int N){a.resize(len=N);}POLY(){rs(M);}
int &operator [](int x){return a[x];}
friend const POLY operator *(POLY A_,const int x){for(int i=0;i<A_.len;i++)A_[i]=1ll*A_[i]*x%mod;return A_;}
void dft(int *A_,int lg,int ln){for(int i=0;i<bin(lg);i++)A_[i]=(i<min(ln,len)?a[i]:0);ntt.dft(A_,lg);}
void idft(int *A_,int lg,int ln){ntt.dft(A_,lg,1);rs(ln);for(int i=0;i<ln;i++)a[i]=1ll*A_[i]*inv[bin(lg)]%mod;}
const POLY Mul(POLY b,int ln=M){
int lg=ceil(log2(2*ln-1)),limit=bin(lg);dft(A,lg,ln);b.dft(B,lg,ln);
for(int i=0;i<limit;i++)B[i]=1ll*A[i]*B[i]%mod;b.idft(B,lg,ln);return b;
}
}F,G;
void getinv(POLY &f,POLY &g,int ln=M)
{
if(ln==1){g.rs(1);g[0]=INV(f[0]);return;}getinv(f,g,(ln+1)>>1);
int lg=ceil(log2(2*ln-1));f.dft(A,lg,ln);g.dft(B,lg,ln);
for(int i=0;i<bin(lg);i++)B[i]=1ll*(2ll-1ll*A[i]*B[i]%mod+mod)%mod*B[i]%mod;g.idft(B,lg,ln);
}
void getSqrt(POLY &f,POLY &g,int ln=M)
{
if(ln==1){g.rs(1);g[0]=1;return;}getSqrt(f,g,(ln+1)>>1);POLY p=g*2,pp;getinv(p,pp,ln);
int lg=ceil(log2(ln+ln-1));f.dft(A,lg,ln);g.dft(B,lg,ln);pp.dft(C,lg,ln);
for(int i=0;i<bin(lg);i++)C[i]=1ll*(1ll*B[i]*B[i]%mod+A[i])%mod*C[i]%mod;g.idft(C,lg,ln);
}
int main()
{
scanf("%d",&n);F.rs(n);M=n;
for(int i=0;i<n;i++)scanf("%d",&F[i]);
getSqrt(F,G,bin((int)ceil(log2(n))));
for(int i=0;i<n;i++)printf("%d ",G[i]);
}