CodeForces - 1264D2 Beautiful Bracket Sequence(生成函數 + 組合計數)

在這裏插入圖片描述
在這裏插入圖片描述

大致題意

給你一個由左右括號和?組成的字符串,現在?可以替換成左右括號的任意一個。定義一個字符串的深度爲最大的左右括號嵌套數。現在問,所有的替換方案產生的字符串的深度總和是多少。

做法

如果有nn括號,那麼就會有2n2^n個字符串,顯然直接計算不可以。

考慮一個字符爲 ‘(’ 的位置ii,如果他要對最後的深度產生影響,當且僅當它和它左邊的 ‘(’ 數目小於等於它右邊的 ‘)’ 數目。那麼,我們就可以考慮枚舉每一個位置,看每一個位置對最後答案的貢獻。

考慮位置ii,它和它左邊有aa個左括號,bb個問號,它右邊有cc個右括號,dd個問號。那麼顯然,它的貢獻就是:
a+x<c+yC(b,x)C(d,y)\sum_{a+x<c+y}C(b,x)*C(d,y)
z=dyz=d-y,那麼可以寫成:
x+z<c+daC(b,x)C(d,z)\sum_{x+z<c+d-a}C(b,x)*C(d,z)
我們知道,對於組合數C(n,m)C(n,m),他的生成函數是(x+1)n(x+1)^n,其中第mm項係數就是C(n,m)C(n,m)的答案。那麼上式可以寫成兩個生成函數的乘積。
(x+1)b(x+1)d=(x+1)b+d(x+1)^b*(x+1)^d=(x+1)^{b+d}
由於x+z<c+dax+z<c+d-a,所以答案就是:
k<c+daC(b+d,k)\sum_{k<c+d-a}C(b+d,k)
然後我們發現,b+db+d相當於是所有問號的個數是一個固定的值,也就只需要求一次組合數的前綴和即可。因此直接對每個爲左括號的位置統計貢獻即可。對於爲問號的位置,由於它也可以變成左括號,因此我們要把他們變成左括號統計他們的貢獻,這時b+db+d會減一。相當於總的只需要求兩次組合數的前綴和即可,時間複雜度O(N)O(N)

代碼

懶得寫線性求逆了,實際複雜度O(NlogN)O(NlogN),但是可以寫到O(N)O(N)

#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define LL long long
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<"      :   "<<x<<endl
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

const int mod = 998244353;
const int N = 1e6 + 7;

unordered_map<int,vector<int> > mp;
char s[N];

LL qpow(LL x,LL n)
{
    LL res=1;
    while(n)
    {
        if (n&1) res=res*x%mod;
        x=x*x%mod; n>>=1;
    }
    return res;
}

LL cal(LL l,LL r,LL L,LL R)
{
    LL n=L+R,m=r+R-l;
    if (m<0) return 0;
    if (n<m) m=n;
    if (mp.count(n)) return mp[n][m];
    LL res=1,sum=1;
    std::vector<int> v;
    for(int i=0;i<=n;i++)
    {
        v.pb(sum);
        res=res*(n-i)%mod*qpow(i+1,mod-2)%mod;
        sum=(sum+res)%mod;
    }
    mp[n]=v;
    return mp[n][m];
}

int main(int argc, char const *argv[])
{
    LL ans=0;
    scanf("%s",s);
    int len=strlen(s);
    int l=0,L=0,r=0,R=0;
    for(int i=len-1;i>=0;i--)
    {
        if (s[i]==')') r++;
        if (s[i]=='?') R++;
    }
    for(int i=0;i<len;i++)
    {
        if (s[i]=='(') l++;
        if (s[i]=='?') L++,R--;
        if (s[i]==')') r--;
        if (s[i]=='(')ans=(ans+cal(l,r,L,R))%mod;
        if (s[i]=='?')
        {
            l++,L--;
            ans=(ans+cal(l,r,L,R))%mod;
            l--,L++;
        }
    }
    printf("%lld\n",ans);
    return 0;
}
發佈了391 篇原創文章 · 獲贊 138 · 訪問量 12萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章