題目鏈接:https://www.luogu.com.cn/problem/P5664
觀察題目數據範圍,發現前64pts可以用類似狀壓的思想來做(m<=3)。前84pts可以在O(n^3*m)的時間內完成。100pts需要在O(n^2*m)的時間內做。
總述:
注意總的初始化,初始化要爲1,因爲後面有乘的操作,最後的時候再將那個多餘的1減去。
64pts:
設f[i][j][k][q]表示到第i行,第1列選了j個,第2列選了k個,第3列選了q個。
轉移有四種情況:1.不選 2.選的是第一列的 3.選的是第二列的 4.選的是第三列的
則f[i][j][k][q]=f[i-1][j][k][q]+f[i-1][j-1][k][q]*a[i][1]+f[i-1][j][k-1][q]*a[i][2]+f[i-1][j][k][q-1]*a[i][3]
因爲根據題意,要求j<=(j+k+q)/2,k<=(j+k+q)/2,q<=(j+k+q)/2,化簡爲:
j<=k+q,k<=j+q,q<=j+k。只有滿足這個,ans+=f[n][j][k][q]。
代碼:
1 #include<cstdio> 2 #include<iostream> 3 using namespace std; 4 typedef long long ll; 5 const ll mod=998244353; 6 const int N=50; 7 ll f[N][N][N][N],a[N][N]; 8 ll n,m,ans; 9 int main(){ 10 scanf("%lld%lld",&n,&m); 11 for(int i=1;i<=n;i++) 12 for(int j=1;j<=m;j++) scanf("%lld",&a[i][j]); 13 f[0][0][0][0]=1; 14 for(int i=1;i<=n;i++){ 15 for(int j=0;j<=i;j++) 16 for(int k=0;k<=i;k++) 17 for(int q=0;q<=i;q++){ 18 f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j][k][q])%mod; 19 if(j) f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j-1][k][q]*a[i][1])%mod; 20 if(k) f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j][k-1][q]*a[i][2])%mod; 21 if(q) f[i][j][k][q]=(f[i][j][k][q]+f[i-1][j][k][q-1]*a[i][3])%mod; 22 } 23 } 24 for(int i=0;i<=n/2;i++) 25 for(int j=0;j<=n/2;j++) 26 for(int k=0;k<=n/2;k++){ 27 if(i<=k+j&&j<=i+k&&k<=i+j) ans+=f[n][i][j][k],ans%=mod; 28 } 29 printf("%lld\n",ans-1); 30 return 0; 31 }
84pts:
枚舉列數j,設f[i][k][q]表示到第i行,第j列一共選了k個,其餘所有列一共選了q個。
轉移有三種情況:1.不選 2.選的是第j列的 3.選的不是第j列的
則f[i][k][q]=f[i-1][k][q]+f[i-1][k-1][q]*a[i][j]+f[i-1][k][q-1]*(sum[i]-a[i][j])
然後運用容斥原理,用總的方案數減去不符合的(k>q)方案數即爲答案。
代碼:
1 #include<cstdio> 2 #include<iostream> 3 #include<cstring> 4 using namespace std; 5 typedef long long ll; 6 const ll mod=998244353; 7 const int N=55; 8 ll f[N][N][N],a[N][505],sum[N]; 9 ll n,m,ans=1,res; 10 int main(){ 11 scanf("%lld%lld",&n,&m); 12 for(int i=1;i<=n;i++){ 13 for(int j=1;j<=m;j++) scanf("%lld",&a[i][j]),sum[i]=(sum[i]+a[i][j])%mod; 14 sum[i]%=mod; 15 ans=(ans*(sum[i]+1)%mod)%mod; 16 } 17 for(int j=1;j<=m;j++){ 18 memset(f,0,sizeof(f)); 19 f[0][0][0]=1; 20 for(int i=1;i<=n;i++) 21 for(int k=0;k<=i;k++) 22 for(int q=0;q<=i-k;q++){ 23 f[i][k][q]=(f[i][k][q]+f[i-1][k][q])%mod; 24 if(k) f[i][k][q]=(f[i][k][q]+f[i-1][k-1][q]*a[i][j])%mod; 25 if(q) f[i][k][q]=(f[i][k][q]+f[i-1][k][q-1]*(sum[i]-a[i][j]))%mod; 26 } 27 for(int k=1;k<=n;k++){ 28 for(int q=0;q<=n-k;q++){ 29 if(k>q) res+=f[n][k][q]; 30 } 31 res=(res+mod)%mod; 32 } 33 } 34 printf("%lld\n",(ans-res-1+mod)%mod); 35 return 0; 36 }
100pts:
可以發現,在84pts中的做法中,0<=k,q<=n,且k<=q,所以-n<=k-q<=n,然後便可以將上面的壓成二維:
f[i][j]表示選到了第i行,j=k-q+n,j∈[0,2n]。
則f[i][j]=f[i-1][j]+f[i-1][j-1]*a[i][j]+f[i-1][j+1]*(sum[i]-a[i][j])。
其中n+1<=j<=2*n是不符合題意的方案數,減掉即爲答案。注意初始化,因爲整體加了n,所以f[0][n]=1。
代碼:
1 #include<cstdio> 2 #include<iostream> 3 #include<cstring> 4 using namespace std; 5 typedef long long ll; 6 const ll mod=998244353; 7 const int N=105; 8 ll f[N][N<<1],a[N][2005],sum[N]; 9 ll n,m,ans=1,res; 10 int main(){ 11 scanf("%lld%lld",&n,&m); 12 for(int i=1;i<=n;i++){ 13 for(int j=1;j<=m;j++) scanf("%lld",&a[i][j]),sum[i]=(sum[i]+a[i][j])%mod; 14 sum[i]%=mod; 15 ans=(ans*(sum[i]+1)%mod)%mod; 16 } 17 for(int j=1;j<=m;j++){ 18 memset(f,0,sizeof(f)); 19 f[0][n]=1; 20 for(int i=1;i<=n;i++) 21 for(int k=n-i;k<=n+i;k++){ 22 f[i][k]=(f[i][k]+f[i-1][k])%mod; 23 if(k) f[i][k]=(f[i][k]+f[i-1][k-1]*a[i][j])%mod; 24 f[i][k]=(f[i][k]+f[i-1][k+1]*(sum[i]-a[i][j])%mod)%mod; 25 } 26 for(int k=n+1;k<=n*2;k++){ 27 res+=f[n][k]; 28 res=(res+mod)%mod; 29 } 30 } 31 printf("%lld\n",(ans-res-1+mod)%mod); 32 return 0; 33 }