Codeforces Round #414 (Div1+Div2) G Replace All (組合數學)

考慮給定兩個M、N串的情況:
定義:兩個01串S,T(|S||T|)coprime 的當且僅當S=T 。或者如果ST 的一個前綴,並令T=S+X ,如果S,Xcoprime 的,那麼S,T 也是coprime 的。

引理1:如果兩個串S,T(|S||T|)coprime 的,則S+T=T+S

證明:我們對|T| 進行數學歸納。
|T|=1 ,則|S|=1 ,由於S,Tcoprime 的,則S=T ,顯然S+T=T+S ,命題成立。
不妨設當|T|k 時命題成立。
|T|=k+1 時,若此時|S|=k+1 ,則命題成立。
|S|=x<k+1 ,則ST 的一個前綴,設T=S+X ,顯然|X|,|S|k
則我們要證S+S+X=S+X+S ,即S+X=X+S 。由於S,Xcoprime 的,所以上式成立,所以當|T|=k+1 時成立
原命題得證

引理2:MN ,則S,T 一定是coprime 的。
證明方法同上

然後繼續分析固定MN 的答案情況。
如果M=N ,那麼任意一個01 串都可以。
如果MN ,由於引理1和引理2,我們可以將串MN 進行排序,使得A 全在前面,B 全在後面。我們刪去兩個串的最長公共前綴,我們用(MA,MB) 來表示刪去最長公共前綴後串M 擁有的AB 的個數。
如果MA>NA,MB>NB 顯然無解。
先考慮MA>NA,MB<NB ,設x=MANA,y=NBMB 。現在我們需要讓xS=yT ,令T=S+X ,顯然x>y ,那麼代換後可得(xy)S=yX ,不斷如此進行下去可以得到(1,1) 和串XX 。反觀這個過程,我們實際上是(yXX,xXX) 。所以我們只要找到合法的XX 即可。

如果(MA,MB)=(NA,NB) 那麼任意一組coprimeST 都滿足情況。我們需要計算這樣的個數。假設|S|=p,|T|=q ,那麼方案數爲2gcd(p,q) ,證明方法同上。

所以我們只需要計算2gcd(p,q),1p,qN
假設cnt[i] 表示的是i|gcd(p,q)(p,q) 對數,ans[i] 表示的是gcd(p,q)=i 的對數,那麼

ans[i]=j=1Ni(μ(j)cnt[ij])

我們總結下前面的推導。
dA=MANA,dB=MBNB
如果dA=dB=0 ,則答案爲ans[N]
如果dA,dB<0dA,dB>0 無解
否則,設p=|dA|,q=|dB|,d=gcd(p,q) ,那麼ans=2Nmax(p/d,q/d)+12

下來考慮問題的原版本
有上述分析可知,對於確定串的方案數如果MN ,僅和dAdB 相關。所以我們只要能統計出這個的不同方案即可。
假設一開始MN 串中A,B 出現的次數之差爲(p,q)M,N 串中問號的個數爲a,b
假設M,N 中分別將x,y 個問號替換爲A 。那麼新的差爲(p+(xy),q+(ab)(xy)) ,爲了簡化我們將q+(ab)q 來代替。則差爲(p+(xy),q(xy))x 的範圍爲[0,a]y 的範圍爲[0,b]
d=xy ,我們只要枚舉d ,並且計算出造成這個d 的方案數和其對應的答案即可。

對於一個確定的d ,我們計算造成這個d 的方案數

(ax)(by)=(ax)(bby)=(ax)(bb+dx),0xa
這個和爲(a+bb+d)

在最後我們利用DP考慮下兩個串相同的方案數即可。至此,該問題結束。

第一次寫這麼長的題解,完結撒花!

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<vector>
#include<bitset>

using namespace std;
typedef long long LL;
const LL mod=1e9+7;
string s;
string t;
int n;
int l1,l2;
LL pow_mod(LL a,LL e)
{
    LL res=1;
    for (;e;a=a*a%mod,e>>=1) if (e&1) res=res*a%mod;
    return res;
}
int miu[610000];
int prim[610000],primm;
bool valid[610000];
void mobius(int N)
{
    miu[1]=1;
    for (int i=2;i<=N;i++)
    {
        if (!valid[i]) prim[++primm]=i,miu[i]=-1;
        for (int j=1;j<=primm&&i*prim[j]<=N;j++)
        {
            valid[i*prim[j]]=1;
            if (i%prim[j]==0)
            {
                miu[i*prim[j]]=0;
                break;
            }
            else miu[i*prim[j]]=-miu[i];
        }
    }
}
LL now;
int na,nb,ma,mb,nn,mm;
LL cnt[310000],ans[310000];
LL res,all;
LL f[610000],ff[610000],inv[610000];
LL C(int n,int m)
{
    LL res=(f[n]*inv[m])%mod;
    res=(res*inv[n-m])%mod;
    return res;
}
LL gcd(LL a,LL b)
{
    return b==0?a:gcd(b,a%b);
}
int main()
{
    mobius(600000);
    cin>>s;
    cin>>t;
    cin>>n;
    l1=s.size();
    l2=t.size();
    if (l1==l2)
    {
        now=1;
        for (int i=0;i<l1;i++)
            if (s[i]=='A')
            {
                if (t[i]=='A'||t[i]=='?');
                else now=0;
            }
            else if (s[i]=='B')
            {
                if (t[i]=='B'||t[i]=='?');
                else now=0;
            }
            else 
            {
                if (t[i]=='A'||t[i]=='B');
                else now=now*2%mod;
            }
        if (now)
        {
            LL tmp=(pow_mod(2,n+1)-2+mod)%mod;
            tmp=tmp*tmp%mod;
            res=tmp*now%mod;
        }
    }
    for (int i=0;i<l1;i++)
        if (s[i]=='A') na++;
        else if (s[i]=='B') nb++;
        else nn++;
    for (int i=0;i<l2;i++)
        if (t[i]=='A') ma++;
        else if (t[i]=='B') mb++;
        else mm++;
    f[0]=1;
    for (int i=1;i<=600010;i++)
        f[i]=(f[i-1]*i)%mod;
    ff[1]=ff[0]=inv[1]=inv[0]=1;  
    for (int i=2;i<=600010;i++)
    {
        inv[i]=(LL)(mod-mod/i)*inv[mod%i]%mod;
        ff[i]=inv[i];
    }
    for (int i=2;i<=600010;i++)
        inv[i]=(inv[i-1]*inv[i])%mod;
    for (int i=1;i<=n;i++)
        cnt[i]=LL(n/i)*(LL)(n/i);
    for (int i=1;i<=n;i++)
    {
        for (int j=1;j<=n/i;j++)
            ans[i]+=(LL)miu[j]*cnt[i*j];
        all+=ans[i]*pow_mod(2,i);
        all%=mod;
    }
    int p=na-ma,q=nb-mb+(nn-mm);
    for (int i=-mm;i<=nn;i++)
    {
        LL tmp=C(nn+mm,mm+i);
        int np=p+i,nq=q-i;
        if (np==0&&nq==0) tmp=(tmp-now+mod)%mod;
        if (np==nq&&np==0) res=(res+tmp*all)%mod;
        else if (np<=0&&nq<=0||np>=0&&nq>=0);
        else 
        {
            int d=gcd(np,nq);
            np/=d,nq/=d;
            tmp=tmp*(pow_mod(2,n/max(abs(np),abs(nq))+1)-2+mod)%mod;
            res=(res+tmp)%mod;
        }
    }
    cout<<res<<endl;
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章