首先發現 的貢獻和 的貢獻是可以分開計算的,把一個DP狀態看成是二維座標上的一個點,轉移就是每次從左邊 或者從下面 轉移。
先考慮計算 的貢獻。
考慮枚舉兩個點 ,計算轉移到 的那一步對 的貢獻,從兩點之間的路徑條數入手可以得到:
這樣顯然無法計算,轉換成枚舉 ,
拆開組合數,
先把最後的 用等比數列求和求出來,然後換成枚舉 ,然後就可以卷積了,
再考慮計算 的貢獻, 同理。
對於每一個 ,可以分別計算貢獻求和:
組合數裏 是因爲最後一步必須選擇縱座標 ,但這樣還是無法計算,因爲 有 個,於是還是考慮枚舉差 ,
其中 ,可以先卷積預處理出來。然後像之前一樣枚舉 再捲起來就行了。
代碼:
#include<iostream>
#include<cstdio>
#include<cstring>
#define N 524290
#define ll long long
#define up(x,y) (x=(x+(y))%mod)
using namespace std;
const int mod=998244353;
const int g=3;
ll ksm(ll a,ll b){ll r=1;for(b=(b+mod-1)%(mod-1);b;b>>=1){if(b&1)r=r*a%mod;a=a*a%mod;}return r;}
ll inv(ll x){return ksm(x,mod-2);}
int n,m,r[N];
ll h,hn,A,B,P,Q,a[N],b[N],ans,fac[N],ifac[N],p[N],q[N],pow_P[N],pow_Q[N],pow_h[N],pow_hn[N],c[N];
void ntt(ll a[],int m,int dft)
{
for(int i=0;i<m;i++)
r[i]=(r[i>>1]>>1)|((i&1)*(m>>1));
for(int i=0;i<m;i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=1;i<m;i<<=1)
{
ll wn=ksm(g,(mod-1)/(i<<1)*dft);
for(int j=0;j<m;j+=(i<<1))
{
ll wk=1;
for(int k=j;k<j+i;k++)
{
ll x=a[k],y=wk*a[k+i]%mod;
a[k]=(x+y)%mod;a[k+i]=(x-y+mod)%mod;
wk=wk*wn%mod;
}
}
}
if(dft==-1) for(int i=0,t=inv(m);i<m;i++) a[i]=a[i]*t%mod;
}
void multi(ll a[],ll b[],ll c[])
{
ntt(a,m,1);ntt(b,m,1);
for(int i=0;i<m;i++)
c[i]=a[i]*b[i]%mod;
ntt(c,m,-1);
}
int main()
{
scanf("%d%lld%lld%lld",&n,&h,&A,&B);
ll pa,pb,qa,qb;
scanf("%lld%lld%lld%lld",&pa,&pb,&qa,&qb);
P=pa*inv(pb)%mod;Q=qa*inv(qb)%mod;
for(int i=1;i<=n;i++)
scanf("%lld",&a[i]);
for(int i=1;i<=n;i++)
scanf("%lld",&b[i]);
for(m=1;m<((n+1)<<1);m<<=1);
hn=ksm(h,n+1);
fac[0]=1;
for(int i=1;i<=m;i++)
fac[i]=fac[i-1]*i%mod;
ifac[m]=inv(fac[m]);
for(int i=m-1;i>=0;i--)
ifac[i]=ifac[i+1]*(i+1)%mod;
pow_P[0]=pow_Q[0]=pow_h[0]=pow_hn[0]=1;
for(int i=1;i<=n;i++)
{
pow_P[i]=pow_P[i-1]*P%mod;
pow_Q[i]=pow_Q[i-1]*Q%mod;
pow_h[i]=pow_h[i-1]*h%mod;
pow_hn[i]=pow_hn[i-1]*hn%mod;
}
for(int i=1;i<=n;i++)
up(ans,pow_hn[i]*a[i]+pow_h[i]*b[i]);
//cal A B
ll ih=inv(h-1),ihn=inv(hn-1);
for(int i=0;i<n;i++)
{
p[i]=pow_P[i]*pow_hn[i+1]%mod*(pow_hn[n-i]-1)%mod*ifac[i]%mod*ihn%mod;
q[i]=pow_Q[i]*pow_h[i+1]%mod*(pow_h[n-i]-1)%mod*ifac[i]%mod*ih%mod;
}
multi(p,q,c);
for(int i=0;i<(n<<1);i++)
up(ans,(P*A+Q*B)%mod*fac[i]%mod*c[i]);
//cal a
for(int i=0;i<=(n>>1);i++)
swap(a[i],a[n-i]);
memset(p,0,sizeof(p));
for(int i=1;i<=n;i++)
p[i]=pow_hn[i];
multi(a,p,a);
memset(p,0,sizeof(p));
memset(q,0,sizeof(q));
for(int i=0;i<n;i++)
p[i]=pow_P[i]*ifac[i]%mod*a[i+n]%mod;
for(int i=1;i<=n;i++)
q[i]=pow_Q[i]*pow_h[i]%mod*ifac[i-1]%mod;
multi(p,q,c);
for(int i=1;i<(n<<1);i++)
up(ans,fac[i-1]*c[i]%mod);
//cal b
for(int i=0;i<=(n>>1);i++)
swap(b[i],b[n-i]);
memset(p,0,sizeof(p));
for(int i=1;i<=n;i++)
p[i]=pow_h[i];
multi(b,p,b);
memset(p,0,sizeof(p));
memset(q,0,sizeof(q));
for(int i=1;i<=n;i++)
p[i]=pow_P[i]*pow_hn[i]%mod*ifac[i-1]%mod;
for(int i=0;i<n;i++)
q[i]=pow_Q[i]*ifac[i]%mod*b[i+n]%mod;
multi(p,q,c);
for(int i=1;i<(n<<1);i++)
up(ans,fac[i-1]*c[i]%mod);
printf("%lld",ans);
return 0;
}