【LOJ6289】花朵(樹上揹包)(NTT)(鏈分治)(帶權二分)

傳送門


題解:

老年選手搞了半個月的文化課開始康復訓練。

很顯然要求的就是個獨立集形式的樹上揹包。

寫成卷積的形式鏈分治+帶權二分即可。

很好寫,拿下LOJ rk1。

複雜度 O(nlog2n)O(n\log^2n),分析方式類似全局平衡二叉樹。

看了下AC代碼,除了我和rk2,剩下的分治部分似乎都是普通二分而不是帶權二分,可以卡到 O(nlog3n)O(n\log^3 n)

rk2可能是慢在取模優化得不夠極限,其實都差不多。


代碼:

#include<bits/stdc++.h>
#define ll long long
#define re register
#define cs const

namespace IO{
	inline char gc(){
		static cs int Rlen=1<<22|1;static char buf[Rlen],*p1,*p2;
		return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
	}template<typename T>T get_integer(){
		char c;bool f=false;while(!isdigit(c=gc()))f=c=='-';T x=c^48;
		while(isdigit(c=gc()))x=((x+(x<<2))<<1)+(c^48);return f?-x:x;
	}inline int gi(){return get_integer<int>();}
}using namespace IO;

using std::cerr;
using std::cout;

cs int mod=998244353;
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int dec(int a,int b){return a-b<0?a-b+mod:a-b;}
inline int mul(int a,int b){ll r=(ll)a*b;return r>=mod?r%mod:r;}
inline void Inc(int &a,int b){a+=b-mod;a+=a>>31&mod;}
inline void Dec(int &a,int b){a-=b;a+=a>>31&mod;}
inline void Mul(int &a,int b){a=mul(a,b);}
inline int po(int a,int b){int r=1;for(;b;b>>=1,Mul(a,a))if(b&1)Mul(r,a);return r;}
inline void ex_gcd(int a,int b,int &x,int &y){
	if(!b){x=1,y=0;return;}ex_gcd(b,a%b,y,x);y-=a/b*x;
}inline int inv(int a){int x,y;ex_gcd(mod,a,y,x);return x+(x>>31&mod);}

template<class ...Args>
inline int add(int a,cs Args& ...args){return add(a,add(args...));}
template<class ...Args>
inline int mul(int a,cs Args& ...args){return mul(a,mul(args...));}

cs int bit=18,SIZE=1<<bit|7;

int r[SIZE],*w[bit+1],Log[SIZE],len,inv_len;
inline void init_omega(){
	for(int re i=1;i<=bit;++i)
		w[i]=new int[1<<(i-1)];
	int wn=po(3,(mod-1)>>bit);
	for(int re i=w[bit][0]=1;i<(1<<(bit-1));++i)
		w[bit][i]=mul(w[bit][i-1],wn);
	for(int re j=bit-1;j;--j)
		for(int re i=0;i<(1<<(j-1));++i)
			w[j][i]=w[j+1][i<<1];
	for(int re i=2;i<SIZE;++i)
		Log[i]=Log[i-1]+((1<<Log[i-1])<i);
}void DFT(int *A){
	for(int re i=1;i<len;++i)
		if(i<r[i])std::swap(A[i],A[r[i]]);
	for(int re i=1,d=1;i<len;i<<=1,++d)
		for(int re j=0;j<len;j+=i<<1)
			if(i<8){
				for(int re k=0;k<i;++k){
					int &t1=A[j+k],&t2=A[j+k+i];
					int t=mul(t2,w[d][k]);
					t2=dec(t1,t);Inc(t1,t);
				}
			}else {
				for(int re k=0;k<i;k+=8){
#define work(p)	\
{				\
	int &t1=A[j+k+p],&t2=A[j+k+i+p];	\
	int t=mul(t2,w[d][k+p]);			\
	t2=dec(t1,t);Inc(t1,t);				\
}
					work(0);work(1);work(2);work(3);
					work(4);work(5);work(6);work(7);
				}
			}
}void IDFT(int *A){
	DFT(A);std::reverse(A+1,A+len);
	for(int re i=0;i<len;++i)Mul(A[i],inv_len);
}void init_len(int l){
	len=l,inv_len=inv(l);
	for(int re i=1;i<l;++i)r[i]=r[i>>1]>>1|((i&1)?l>>1:0);
}

using Poly=std::vector<int>; 
void DFT(Poly &A){DFT(&A[0]);}
void IDFT(Poly &A){IDFT(&A[0]);}

cs int N=8e4+7;

int n,m;

struct matrix{
	Poly a[2][2];matrix(){}
	inline int deg()cs{
		return std::max(
			std::max(a[0][0].size(),a[0][1].size()),
			std::max(a[1][0].size(),a[1][1].size())
		);
	}
	inline void clear(){
		a[0][0].clear(),a[0][1].clear();
		a[1][0].clear(),a[1][1].clear();
	}
	inline cs Poly* operator[](int o)cs{return a[o];}
	inline Poly* operator[](int o){return a[o];}
};

matrix& operator*=(matrix &A,cs matrix &B){
	static int a[2][2][SIZE],b[2][2][SIZE],c[2][2][SIZE];
	init_len(1<<Log[A.deg()+B.deg()-1]);
	for(int re i=0;i<2;++i)for(int re j=0;j<2;++j){
		cs Poly &va=A[i][j],&vb=B[i][j];
		memcpy(a[i][j],&va[0],va.size()<<2);
		memset(a[i][j]+va.size(),0,(len-va.size())<<2);
		memcpy(b[i][j],&vb[0],vb.size()<<2);
		memset(b[i][j]+vb.size(),0,(len-vb.size())<<2);
		DFT(a[i][j]);DFT(b[i][j]);
	}
	for(int re i=0;i<2;++i)for(int re j=0;j<2;++j){
		for(int re p=0;p<len;++p)
			c[i][j][p]=add(mul(a[i][1][p],b[0][j][p]),
				mul(a[i][0][p],add(b[0][j][p],b[1][j][p])));
	}
	for(int re i=0;i<2;++i)for(int re j=0;j<2;++j){
		IDFT(c[i][j]);int d=std::min(len-1,m);
		while(~d&&!c[i][j][d])--d;
		if(~d)A[i][j]=Poly(c[i][j],c[i][j]+d+1);
		else A[i][j].clear();
	}return A;
}

Poly operator*(cs Poly &a,cs Poly &b){
	static int A[SIZE],B[SIZE];
	int deg=a.size()+b.size()-1;init_len(1<<Log[deg]);
	memcpy(A,&a[0],a.size()<<2);
	memset(A+a.size(),0,(len-a.size())<<2);
	memcpy(B,&b[0],b.size()<<2);
	memset(B+b.size(),0,(len-b.size())<<2);
	DFT(A);DFT(B);for(int re i=0;i<len;++i)Mul(A[i],B[i]);
	IDFT(A);deg=std::min(deg-1,m)+1;
	while(deg>0&&!A[deg-1])--deg;return Poly(A,A+deg);
}
Poly& operator+=(Poly &A,cs Poly &B){
	if(A.size()<B.size())A.resize(B.size());
	for(int re i=0;i<(int)B.size();++i)Inc(A[i],B[i]);
	return A;
} 

std::vector<int> G[N];
void adde(int u,int v){
	G[u].push_back(v);
	G[v].push_back(u);
}

matrix dp[N];

int vl[N];

int fa[N],sz[N],son[N];

void pre_dfs(int u,int p){
	fa[u]=p;sz[u]=1;
	for(int re v:G[u])
		if(v!=p){
			pre_dfs(v,u);sz[u]+=sz[v];
			if(sz[v]>sz[son[u]])son[u]=v;
		}
}

int st[N],tp;
Poly p[N];
int pr[N];

Poly merge_son(int l,int r){
	if(l==r)return p[l];
	int m=std::lower_bound(pr+l,pr+r+1,(pr[l]+pr[r])>>1)-pr;
	if(m==r)--m;return merge_son(l,m)*merge_son(m+1,r);
}

void merge_chain(int l,int r){
	if(l==r)return ;
	int m=std::lower_bound(pr+l,pr+r+1,(pr[l]+pr[r])>>1)-pr;
	if(m==r)--m;merge_chain(l,m),merge_chain(m+1,r);
	dp[st[l]]*=dp[st[m+1]];
}

void dfs_solve(int u){
	for(int re v:G[u])
		if(v!=fa[u])dfs_solve(v);
	tp=0;
	for(int re v:G[u])
		if(v!=fa[u]&&v!=son[u]){
			p[++tp].clear();
			p[tp].resize(dp[v].deg());
			p[tp]+=dp[v][0][0];
			p[tp]+=dp[v][0][1];
			p[tp]+=dp[v][1][0];
			p[tp]+=dp[v][1][1];
			pr[tp]=pr[tp-1]+p[tp].size();
		}
	if(tp)dp[u][0][0]=merge_son(1,tp);else dp[u][0][0].push_back(1);
	tp=0;
	for(int re v:G[u])
		if(v!=fa[u]&&v!=son[u]){
			p[++tp].clear();
			p[tp]+=dp[v][0][0];
			p[tp]+=dp[v][0][1];
			pr[tp]=pr[tp-1]+p[tp].size();
		}
	if(tp)dp[u][1][1]=merge_son(1,tp);else dp[u][1][1].push_back(1);
	for(int &v:dp[u][1][1])Mul(v,vl[u]);dp[u][1][1].insert(dp[u][1][1].begin(),0);
	if(u!=son[fa[u]]){tp=0;
		for(int re p=u;p;p=son[p])
			st[++tp]=p,pr[tp]=pr[tp-1]+dp[p].deg();
		merge_chain(1,tp);
	}
}

void Main(){
	n=gi(),m=gi();init_omega();
	for(int re i=1;i<=n;++i)vl[i]=gi();
	for(int re i=1;i<n;++i)adde(gi(),gi());
	pre_dfs(1,0);dfs_solve(1);int ans=0;
	for(int re i=0;i<2;++i)for(int re j=0;j<2;++j)
		if((int)dp[1][i][j].size()>m)Inc(ans,dp[1][i][j][m]);
	cout<<ans<<"\n";
}

inline void file(){
#ifdef zxyoi
	freopen("flower.in","r",stdin);
#endif 
}
signed main(){file();Main();return 0;} 
發佈了993 篇原創文章 · 獲贊 374 · 訪問量 7萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章