AGC019 E.Shuffle and Swap-DP+NTT

傳送門

題意:

給出兩個01串A,b,記ai 表示A中1的出現位置,bi 表示B中1的出現位置,將a數組和b數組打亂後依次次交換AaiAbi ,求有幾種方式使得A=B

字符串長度<=10000

Solution:

我們可以把答案拆分成兩步:

1.枚舉a和b的匹配

2.打亂匹配順序

假設我們已經完成了操作1,我們來計算每個匹配所能產生的期望合法方案

嘗試轉化一下模型:對於一個給定的匹配,我們從aibi 連一條有向邊,可以發現這個圖最終由若干個環和若干條鏈構成,且鏈的順序是唯一的

假設有e個Ai=Bi=1 ,m個Ai=1,Bi=0 ,可以發現邊數爲e+m,圖由m條鏈和若干環組成

考慮將e個點分配到m條鏈中,f[i][j] 表示前i條鏈分到j個點的期望合法方案

那麼有轉移:f[i][j]=u=0ujf[i1][ju](u+1)! (爲什麼要除(u+1)!呢?因爲一共加入了u+1條邊,這些邊有(u+1)!種匹配方式,而在這些匹配方式中只有一種是合法的)

最終的答案即爲e!m!(e+m)!j=0jef[m][j]

e!表示點的總分配方式,m!表示鏈的不同排序數,(e+m)!表示總匹配數

樸素做法O(n3) ,可以用NTT+快速冪優化到O(nlog2n)

O(n3) 代碼:

#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
char a[100010],b[100010];
int n,num,f[510][510],tot;
const int mod=998244353;
int jc[100010],inv[100010],ans;
int fast_pow(int a,int x)
{
    int ans=1;
    for (;x;x>>=1,a=1ll*a*a%mod)
        if (x&1) ans=1ll*ans*a%mod;
    return ans;
}
int main()
{
    scanf("%s%s",a+1,b+1);
    n=strlen(a+1);
    for (int i=1;i<=n;i++)
    {
        if (a[i]=='1'&&b[i]=='1') num++;
        if (a[i]=='1') tot++;
    }
    jc[0]=1;
    for (int i=1;i<=tot+1;i++) jc[i]=1ll*jc[i-1]*i%mod;
    inv[tot+1]=fast_pow(jc[tot+1],mod-2);
    for (int i=tot;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod; 
    f[0][0]=1;
    for (int i=1;i<=tot-num;i++)
    {
        for (int j=0;j<=num;j++)
            for (int k=0;k<=j;k++)
                f[i][j]=(1ll*f[i-1][j-k]*inv[k+1]+f[i][j])%mod; 
        for (int j=0;j<=num;j++) printf("%d ",f[i][j]);cout<<endl;
    }
    for (int i=0;i<=num;i++)
        ans=(ans+f[tot-num][i])%mod;
    printf("%d",1ll*ans*jc[tot-num]%mod*jc[num]%mod*jc[tot]%mod);
}

O(nlog2n) 代碼:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
using namespace std;
char a[100010],b[100010];
int n,num,tot;
const int mod=998244353;
int jc[100010],inv[100010],ans;
const int G=3;
int x1[100010],x2[100010];
int fast_pow(int a,int x)
{
    int ans=1;
    for (;x;x>>=1,a=1ll*a*a%mod)
        if (x&1) ans=1ll*ans*a%mod;
    return ans;
}
void change(int y[],int len)
{
    int i,j,k;
    for (i=1,j=len/2;i<len-1;i++)
    {
        if (i<j) swap(y[i],y[j]);
        k=len/2;
        while (j>=k) j-=k,k>>=1;
        if (j<k) j+=k; 
    }
    return;
}
void fft(int y[],int len,int ifi)
{
    change(y,len);
    for (int h=2;h<=len;h<<=1)
    {
        int wn=fast_pow(G,(ifi==1)?(mod-1)/h:mod-1-(mod-1)/h);
        for (int j=0;j<len;j+=h)
        {
            int w=1;
            for (int k=j;k<j+h/2;k++)
            {
                int u=y[k];
                int t=1ll*w*y[k+h/2]%mod;
                y[k]=(u+t)%mod;
                y[k+h/2]=(1ll*u-t+mod)%mod;
                w=1ll*w*wn%mod;
            } 
        } 
    }
    if (ifi==-1)
    {
        int iv=fast_pow(len,mod-2);
        for (int i=0;i<len;i++) y[i]=1ll*y[i]*iv%mod;
    }
}
void add(int len)
{
    fft(x2,len,1);
    for (int i=0;i<len;i++) x2[i]=1ll*x2[i]*x2[i]%mod;
    fft(x2,len,-1);
    for (int i=num+1;i<len;i++) x2[i]=0;
}
int main()
{
    scanf("%s%s",a+1,b+1);
    n=strlen(a+1);
    for (int i=1;i<=n;i++)
    {
        if (a[i]=='1'&&b[i]=='1') num++;
        if (a[i]=='1') tot++;
    }
    jc[0]=1;
    for (int i=1;i<=tot+1;i++) jc[i]=1ll*jc[i-1]*i%mod;
    inv[tot+1]=fast_pow(jc[tot+1],mod-2);
    for (int i=tot;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod; 
    for (int i=0;i<=tot;i++) inv[i]=inv[i+1];

    int len=1;
    while (len<=2*num) len<<=1;
    x1[0]=1;
    for (int i=1;i<len;i++) x1[i]=0;
    for (int i=0;i<=num;i++) x2[i]=inv[i];
    for (int i=num+1;i<len;i++) x2[i]=0;

    for (int i=tot-num;i;i>>=1,add(len))
        if (i&1)
        {
            fft(x2,len,1);fft(x1,len,1);
            for (int j=0;j<len;j++) x1[j]=1ll*x1[j]*x2[j]%mod;
            fft(x1,len,-1);fft(x2,len,-1);
            for (int j=num+1;j<len;j++) x1[j]=0;
        }
    int ans=0;
    for (int i=0;i<=num;i++) ans=(ans+x1[i])%mod;
    printf("%d",1ll*ans*jc[tot-num]%mod*jc[num]%mod*jc[tot]%mod);
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章