[鏈剖 FFT] LOJ#6289. 花朵

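樹形DP：設 f[x]、g[x] 分別表示 x 被選、x 不被選時，x 子樹內所有合法方案（被選的點兩兩不相鄰）的生成函數，其中 $z^k$ 的係數是「恰好選 k 個點」的所有方案中被選點權值之積的總和。合併子樹時生成函數相乘，所以轉移是多項式卷積的形式。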

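做法：先做輕重鏈剖分。對每個點，先把它所有輕兒子的多項式合併（每次取長度最短的兩個相乘，類似哈夫曼樹），得到該點選／不選時的多項式；然後一條重鏈上每個點對應一個元素為多項式的 2×2 轉移矩陣，用分治 FFT（這裡是 NTT）把整條鏈的矩陣乘起來。答案是根的 f+g 中 $z^m$ 的係數。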

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>

using namespace std;

// A polynomial over Z_P stored as its coefficient vector (index = degree).
// An empty vector stands for the zero polynomial.
typedef std::vector<int> poly;

// Modulus for all arithmetic (Pre() builds NTT roots from 3^((P-1)/n)).
const int P=998244353;
// Capacity of the global per-node and NTT scratch arrays.
const int N=800010;

// Buffered single-character reader over stdin. Refills a 100000-byte buffer
// with fread when it is exhausted; returns EOF (converted to char) once no
// more input is available.
inline char nc(){
  static char buf[100000];
  static char *p1=buf,*p2=buf;
  if(p1==p2){
    p1=buf;
    p2=buf+fread(buf,1,100000,stdin);
    if(p1==p2) return EOF;
  }
  return *p1++;
}

// Read the next non-negative decimal integer from nc() into x.
// Any non-digit characters before the number are skipped (there is no sign
// handling). If input ends before a digit appears, x is left as 0. The
// original loop kept calling nc() forever there, because nc() returns EOF
// repeatedly once stdin is drained.
inline void read(int &x){
  char c=nc(); x=0;
  while(c<'0'||c>'9'){
    if(c==(char)EOF) return;
    c=nc();
  }
  for(;c>='0'&&c<='9';c=nc()) x=x*10+c-'0';
}

// Input: n nodes, target independent-set size m, node weights a[1..n].
// Adjacency lists: G[x] is the head edge index (0 = end), cnt the last used index.
// DFS results (filled by pfs): heavy child son[], subtree size[], preorder p[1..t], parent fa[].
int n,m,cnt,t;
int a[N],G[N],son[N],size[N],p[N],fa[N];

// One directed arc: target vertex t and the next arc of the same list (nx).
struct edge{
  int t,nx;
};
edge E[N<<2];

// Per-node generating polynomials: f[x] when x is chosen, g[x] when it is not.
poly f[N],g[N];

// NTT state: supported length num, twiddle tables w[1] (forward) / w[0]
// (inverse), and the bit-reversal permutation rev[] used by NTT().
int num,w[2][N],rev[N];

// Insert the undirected edge {x,y} as two arcs (x->y and y->x), each
// prepended to its source's adjacency list.
inline void addedge(int x,int y){
  const int from[2]={x,y}, to[2]={y,x};
  for(int k=0;k<2;k++){
    const int id=++cnt;
    E[id].t=to[k];
    E[id].nx=G[from[k]];
    G[from[k]]=id;
  }
}

void pfs(int x,int f){
  size[x]=1; p[++t]=x; fa[x]=f;
  for(int i=G[x];i;i=E[i].nx)
    if(E[i].t!=f){
      pfs(E[i].t,x);
      size[x]+=size[E[i].t];
      if(size[E[i].t]>size[son[x]]) son[x]=E[i].t;
    }
}

// Binary exponentiation: returns x^y mod P (y is expected to be >= 0).
inline int Pow(int x,int y){
  long long base=x,acc=1;
  while(y){
    if(y&1) acc=acc*base%P;
    base=base*base%P;
    y>>=1;
  }
  return (int)acc;
}

// Build the twiddle tables for transforms of length up to n (main passes a
// power of two). w[1][i] = root^i with root = 3^((P-1)/n), and
// w[0][i] = w[1][n-i] for i >= 1 gives the inverse-transform twiddles.
// Records n in num so NTT() can index the tables with stride num/len.
inline void Pre(const int &n){
  num=n;
  const int root=Pow(3,(P-1)/n);
  w[1][0]=1;
  w[0][0]=1;
  for(int i=1;i<n;i++){
    w[1][i]=1LL*w[1][i-1]*root%P;
    w[0][n-i]=w[1][i];
  }
}

// In-place iterative number-theoretic transform of a[0..n-1] (n a power of two).
// r = 1 is the forward transform (w[1]); r = 0 is the inverse (w[0]), which
// also multiplies by n^{-1}. Expects rev[] to hold the bit-reversal
// permutation for n, and n <= num so num/len is a valid twiddle stride.
inline void NTT(int *a,int n,int r){
  for(int i=1;i<n;i++){
    const int j=rev[i];
    if(j>i) swap(a[i],a[j]);
  }
  for(int half=1;half<n;half*=2){
    const int len=half*2;
    const int step=num/len;
    for(int st=0;st<n;st+=len){
      for(int k=0;k<half;k++){
        int &lo=a[st+k];
        int &hi=a[st+k+half];
        const int u=lo;
        const int v=(int)(1LL*hi*w[r][step*k]%P);
        lo=(u+v)%P;
        hi=(u-v+P)%P;
      }
    }
  }
  if(r==0){
    const int inv=Pow(n,P-2);
    for(int i=0;i<n;i++) a[i]=(int)(1LL*a[i]*inv%P);
  }
}

// Product of two polynomials mod P (coefficient vectors, index = degree).
// An empty vector is the zero polynomial, so any product with it is empty.
// Short inputs use the schoolbook O(|a|*|b|) loop. Longer ones use NTT with
// the root table from Pre(), which only supports transform lengths up to num.
poly operator *(const poly &a,const poly &b){
  if(a.empty() || b.empty()) return poly();
  const int la=a.size(),lb=b.size(),len=la+lb-1;

  // Smallest power of two that holds the whole product without wrap-around.
  // The old sizing (smallest power of two strictly above la+lb) could reach
  // 2*num. In that case num/(i<<1) became 0 in NTT's last level, every
  // twiddle there degenerated to w[r][0] = 1, and the product was wrong.
  int n=1,L=0;
  while(n<len) n<<=1,++L;

  poly ret(len,0);
  if(la+lb<500 || n>num){
    // Schoolbook product; also a safe fallback if the root table is too short.
    for(int i=0;i<la;i++)
      for(int j=0;j<lb;j++)
        ret[i+j]=(ret[i+j]+1LL*a[i]*b[j])%P;
    return ret;
  }

  for(int i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
  static int tmpa[N],tmpb[N];
  // Reduce on entry: the butterflies add two values as int and rely on both being < P.
  for(int i=0;i<la;i++) tmpa[i]=a[i]%P;
  for(int i=0;i<lb;i++) tmpb[i]=b[i]%P;
  NTT(tmpa,n,1); NTT(tmpb,n,1);
  for(int i=0;i<n;i++) tmpa[i]=1LL*tmpa[i]*tmpb[i]%P;
  NTT(tmpa,n,0);
  for(int i=0;i<len;i++) ret[i]=tmpa[i];
  // Leave the scratch buffers zeroed for the next call.
  for(int i=0;i<n;i++) tmpa[i]=tmpb[i]=0;
  return ret;
}

// Coefficient-wise sum of two polynomials. The result is as long as the longer
// input; coefficients of a are copied as-is and those of b are added mod P.
poly operator +(poly a,poly b){
  poly sum(max(a.size(),b.size()),0);
  copy(a.begin(),a.end(),sum.begin());
  for(size_t k=0;k<b.size();++k)
    sum[k]=(sum[k]+b[k])%P;
  return sum;
}

struct polyc{
  poly a00,a01,a10,a11;
  polyc(){}
  polyc(poly a,poly b):a00(a),a11(b){}

  int size(){ return max(max(a00.size(),a01.size()),max(a10.size(),a11.size())); }

  friend polyc operator *(polyc a,polyc b){
    polyc ret;
    ret.a00=a.a00*b.a00+a.a01*b.a00+a.a00*b.a10;
    ret.a01=a.a00*b.a01+a.a00*b.a11+a.a01*b.a01;
    ret.a10=a.a10*b.a00+a.a10*b.a10+a.a11*b.a00;
    ret.a11=a.a10*b.a01+a.a10*b.a11+a.a11*b.a01;
    return ret;
  }

  friend bool operator <(polyc a,polyc b){
    return a.size()>b.size();
  }
};

// A pair of polynomials that is multiplied component-wise. In solve() the
// first component accumulates prod(f+g) and the second prod(g) over a node's
// light children.
struct polypair{
  poly a,b;
  polypair(){}
  polypair(poly _a,poly _b):a(_a),b(_b){}

  // Component-wise product: (a1,b1)*(a2,b2) = (a1*a2, b1*b2).
  friend polypair operator *(polypair a,polypair b){
    polypair prod;
    prod.a=a.a*b.a;
    prod.b=a.b*b.b;
    return prod;
  }

  // Reversed on purpose: "less" means a longer first component, so a
  // std::priority_queue of polypairs yields the shortest pair first.
  friend bool operator <(polypair a,polypair b){
    return b.a.size()<a.a.size();
  }
};

namespace HuffmanFFT{

  // Pending polynomial pairs. polypair's operator< is reversed, so top() is
  // the pair whose first polynomial is shortest.
  priority_queue<polypair> a;

  void Push(poly _a,poly _b){
    polypair item(_a,_b);
    a.push(item);
  }

  // Multiply all queued pairs together, always combining the two shortest
  // first, and return the result. The queue is left empty.
  // Precondition: the queue is non-empty.
  polypair work(){
    for(;;){
      polypair first=a.top(); a.pop();
      if(a.empty()) return first;
      polypair second=a.top(); a.pop();
      a.push(first*second);
    }
  }

}

namespace DivAndConq{
  // Per-node transfer matrices of the chain being processed, in push order.
  vector<polyc> a;

  // Append one node. Argument order matters: _a is stored as a11 and _b as
  // a00 (solve() passes the node's "chosen" polynomial first).
  void Push(poly _a,poly _b){
    polyc node(_b,_a);
    a.push_back(node);
  }

  void Clear(){ a.clear(); }

  // Ordered product a[l] * a[l+1] * ... * a[r], split at the midpoint.
  polyc solve(int l=0,int r=a.size()-1){
    if(l==r) return a[l];
    const int mid=(l+r)/2;
    const polyc upper=solve(l,mid);
    const polyc lower=solve(mid+1,r);
    return upper*lower;
  }
}

inline void solve(int x){
  DivAndConq::Clear();
  for(int u=x;u;u=son[u]){
    for(int i=G[u];i;i=E[i].nx)
      if(E[i].t!=fa[u] && E[i].t!=son[u])
    HuffmanFFT::Push(f[E[i].t]+g[E[i].t],g[E[i].t]);
    polypair cur; if(HuffmanFFT::a.size()) cur=HuffmanFFT::work();
    poly U; U.push_back(0); U.push_back(a[u]);
    if(cur.b.size()) cur.b=cur.b*U; else cur.b=U;
    if(!cur.a.size()) cur.a.push_back(1);
    DivAndConq::Push(cur.b,cur.a);
  }
  polyc cur=DivAndConq::solve();
  f[x]=cur.a10+cur.a11; g[x]=cur.a00+cur.a01;
}

int main(){
  read(n); read(m);
  int _m; for(_m=1;_m<=n;_m<<=1); Pre(_m);
  for(int i=1;i<=n;i++) read(a[i]);
  for(int i=1,x,y;i<n;i++)
    read(x),read(y),addedge(x,y);
  pfs(1,0);
  for(int i=t;i;i--)
    if(son[fa[p[i]]]!=p[i]) solve(p[i]);
  poly ans=f[1]+g[1];
  if(ans.size()>m) printf("%d\n",ans[m]);
  else puts("0");
  return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章