【考試題 - 排列】 分治+樹形DP

題意:求有多少種排列滿足 $i$ 之前第一個小於 $i$ 的位置是 $q[i]$.  

如果沒有 $q[i]$ 的限制,答案就是全排列,然後 $q[i]$ 會限制一些元素之間的大小關係.  

直接做的話沒辦法方便地求出元素之間的大小關係.   

不妨思考單調棧的過程:如果遇到前綴最小值的話肯定會將棧清空. 

那麼也就是說如果最小值 $i$ 將序列分爲 $L,R$,則 $L,R$ 之間相互不影響.         

有上述結論後就可以根據最小值進行分治了,會形成一個樹形結構.    

建出樹後令 $f[x]$ 表示以 $x$ 爲根的子樹有多少種排列滿足限制,然後轉移的話乘上一個組合數就好了. 

code: 

#include <cstdio>  
#include <vector>
#include <cstring>
#include <algorithm>    
#define N 500008 
#define ll long long  
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin)
using namespace std; 
int n,edges; 
int fac[N],inv[N],g[20][N],q[N],Lg[N],hd[N],to[N],nex[N],f[N],size[N];    
void add(int u,int v) { 
    nex[++edges]=hd[u]; 
    hd[u]=edges,to[edges]=v;  
}
int qpow(int x,int y) { 
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod) {    
        if(y&1) tmp=(ll)tmp*x%mod; 
    } 
    return tmp; 
}  
inline int get_inv(int x) { 
    return qpow(x,mod-2); 
}
void init() { 
    fac[0]=1; 
    for(int i=1;i<N;++i) fac[i]=(ll)fac[i-1]*i%mod; 
    inv[1]=1; 
    for(int i=2;i<N;++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;  
    inv[0]=1;  
    for(int i=1;i<N;++i) inv[i]=(ll)inv[i-1]*inv[i]%mod; 
} 
int C(int x,int y) { 
    return (ll)fac[x]*inv[y]%mod*inv[x-y]%mod;  
}        
void build() { 
    for(int i=1;(1<<i)<=n;++i) 
        for(int j=1;j+(1<<i)-1<=n;++j) {
            int a=g[i-1][j],b=g[i-1][j+(1<<(i-1))];   
            if(q[b]<=q[a]) g[i][j]=b; 
            else g[i][j]=a;  
        }
    Lg[1]=0; 
    for(int i=2;i<N;++i) { 
        Lg[i]=Lg[i>>1]+1;  
    }
}
int query(int l,int r) { 
    int det=Lg[r-l+1];  
    return q[g[det][r-(1<<det)+1]]<=q[g[det][l]]?g[det][r-(1<<det)+1]:g[det][l];  
}
int solve(int l,int r) {   
    if(l>r) return 0; 
    int now=query(l,r);   
    if(q[now]!=l-1) {  
        printf("0\n"); 
        exit(0);  
    }    
    int a=solve(l,now-1),b=solve(now+1,r);    
    if(a) add(now,a);  
    if(b) add(now,b);   
    return now;  
}
void dfs(int x) { 
    f[x]=1;  
    for(int i=hd[x];i;i=nex[i]) { 
        int y=to[i]; 
        dfs(y),size[x]+=size[y];  
        f[x]=(ll)f[x]*C(size[x],size[y])%mod*f[y]%mod;    
    } 
    ++size[x];  
}
int main() {  
    // setIO("input");      
    init(); 
    scanf("%d",&n); 
    for(int i=1;i<=n;++i) { 
        scanf("%d",&q[i]);  
        g[0][i]=i;  
    }  
    build();  
    int p=solve(1,n);   
    dfs(p);  
    printf("%d\n",f[p]);  
    return 0; 
}

  

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章