知道結論才能做,而證明結論纔是這題精髓
(但我根本不會,以下均轉載自大佬博客https://blog.csdn.net/xaphoenix/article/details/72681794?utm_source=blogxgwz2)
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
證明真的6,總結一下這個證明實際上感性理解是對於兩個包含A,B的S、T串去掉lcp,那麼A、B較長的那個可以拆成較短的那個加C(長的串前綴減掉短的餘下的),重複此操作直至兩串相同,其實是一個類似輾轉相減法的操作,那麼最後有解的充分必要條件就是S和T是coprime(都是由一個小的01串作爲循環節構成的),之後隨便搞一下就行了。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=3e5+100;
const ll mod=1e9+7;
char S[N],T[N];
int lenS,lenT,n;
ll mi2[2*N],cnt[N],jc[2*N],ijc[2*N],ans=0,A,B,x,y,moha;
void Ad(ll &x,ll y)
{if((x+=y)>=mod)x-=mod;}
void Dw(ll &x,ll y)
{if((x-=y)<0)x+=mod;}
ll sqr(ll x)
{return x*x%mod;}
ll gcd(ll x,ll y)
{return !y?x:gcd(y,x%y);}
ll qpow(ll x,ll y)
{
ll res=1;
while(y)
{
if(y&1)res=res*x%mod;
x=x*x%mod,y>>=1;
}
return res;
}
void pre()
{
mi2[0]=1;
for(int i=1;i<2*N;i++)
mi2[i]=mi2[i-1]*2%mod;
jc[0]=1;
for(int i=1;i<2*N;i++)
jc[i]=jc[i-1]*i%mod;
ijc[2*N-1]=qpow(jc[2*N-1],mod-2);
for(int i=2*N-2;i>=0;i--)
ijc[i]=ijc[i+1]*(i+1)%mod;
ll nw;
for(int i=1;i<=n;i++)
{
Ad(cnt[i],mi2[i]);
Ad(moha,cnt[i]*sqr(n/i)%mod);
for(int j=i+i;j<=n;j+=i)
Dw(cnt[j],cnt[i]);
}
}
ll get(ll x,ll y)
{
if(x==0&&y==0)return moha;
if(x==0||y==0)return 0;
if(x<0)x*=-1,y*=-1;
if(y<0)return 0;
ll g=gcd(x,y);int tmp;
x/=g,y/=g;
tmp=n/max(x,y);
return (mi2[tmp+1]-2+mod)%mod;
}
void chk_same()
{
if(lenS!=lenT)return;
ll tmp=1;
for(int i=1;i<=lenS;i++)
{
if(S[i]=='A'&&T[i]=='B')return;
if(S[i]=='B'&&T[i]=='A')return;
if(S[i]=='?'&&T[i]=='?')tmp=tmp*2%mod;
}
Dw(ans,moha*tmp%mod);//bf not right
Ad(ans,sqr(mi2[n+1]-2+mod)*tmp%mod);//true
}
ll C(ll x,ll y)
{return jc[x]*ijc[y]%mod*ijc[x-y]%mod;}
int main()
{
scanf("%s",S+1),lenS=strlen(S+1);
scanf("%s",T+1),lenT=strlen(T+1);
scanf("%d",&n);
pre(),chk_same();
for(int i=1;i<=lenS;i++)
{
if(S[i]=='A')A++;
if(S[i]=='B')B--;
if(S[i]=='?')x++;
}
for(int i=1;i<=lenT;i++)
{
if(T[i]=='A')A--;
if(T[i]=='B')B++;
if(T[i]=='?')y++;
}
for(ll k=0,nw;k<=x+y;k++)
{
nw=get(A+k-y,B+k-x)*C(x+y,k)%mod;
//cerr<<k<<' '<<nw<<'\n';
ans=(ans+nw)%mod;
}
printf("%lld\n",ans);
}