題目鏈接:點擊查看
題目大意:給出一個長度爲 n 的序列 a ,再給出一個長度爲 m 的序列 b ,題目保證序列 b 是嚴格遞增的,我們需要將 a 分割成恰好 m 段,使得每一段的最小值恰好等於 b[ i ] ,問有多少種分割方法
題目分析:我們需要將序列 a 恰好分爲 m 段的話,如果知道每一段的左右區間的取值範圍後,根據乘法原理不難算出答案,但又因爲每一段的左端點和右端點都是不斷變化的,看似不太好直接求解
仔細分析一下不難看出,因爲相鄰的兩段之間是拼接而成的,所以上一段的終點,也就決定了下一段的起點,反之亦然
又因爲序列 b 是嚴格遞增的,且需要求的是序列 a 中的最小值,所以倒着處理比較方便
這樣一來我們就可以求出每一個 b[ i ] 對應在序列 a 上的起點的左右區間,然後利用乘法原理計算答案就好了
舉個例子,就拿樣例 1 來說
序列 a 爲 { 12 10 20 20 25 30 } ,序列 b 爲 { 10 20 30 }
倒着來看的話,如果想要讓 min( a[ l ] : a[ r ] ) = b[ 3 ] 的話,只能在序列 a 中取 [ 6 , 6 ] 這段區間,這樣一來,b[ 2 ] 終點的位置就確定爲 5 了,看一下 b[ 2 ] 的起點,可以選擇 3 也可以選擇 4 ,即在序列 a 中選擇區間 [ 3 , 5 ] 和 [ 4 , 5 ] 都可以滿足 min( al : ar ) = b[ 2 ] ,此時因爲 b[ 1 ] 的起點一定是位置 1 ,而 b[ 1 ] 的終點已經由 b[ 2 ] 的起點決定了,所以這個樣例的答案爲 2
到這裏可能會有一個疑問,假如 b[ k + 1 ] 起點的選擇範圍是 [ x , y ] ,b[ k ] 當前選擇的起點爲 z ,正常來說當 b[ k + 1 ] 這段選擇的起點爲 x 時,b[ k ] 這段選擇的區間是 [ z : x - 1 ] ,這裏的 min( a[ z ] : a[ x - 1 ] ) = b[ k ] 是顯然的,那麼當 b[ k + 1 ] 這段如果選擇的起點是 x + 1 時,如何保證 min( a[ z ] : a[ x ] ) 這段也是 b[ k ] 呢?因爲之前 a[ x ] 這個元素包含在 b[ k + 1 ] 這段中,所以 a[ x ] >= b[ k + 1 ] ,又因爲 b[ k + 1 ] > b[ k ] (已知條件),所以 a[ x ] >= b[ k + 1 ] > b[ k ] ,所以 min( a[ z ] : a[ x ] ) = min( min( a[ z ] : a[ x - 1 ] ) , a[ z ] ) = min( b[ k ] , a[ x ] ) = b[ k ]
說道這裏就可以想到維護一個最小值的後綴然後判斷了,設 mmin[ i ] = min( a[ i ] : a[ n ] ) ,根據上面的那一段可知,如果我們處理到了第 k 個位置,也就是需要處理第 b[ k ] 段的區間,只要 [ k + 1 , m ] 這些區間都合法的話,那麼 mmin[ i ] = b[ k ] 的這些位置都是可以在第 k 段上當做起點的位置,且是連續的,因爲數據比較大,所以用 map 離散的記一下數
代碼:
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int N=2e5+100;
const int mod=998244353;
int a[N],b[N];
int mmin[N];
map<int,int>cnt;
int main()
{
#ifndef ONLINE_JUDGE
// freopen("input.txt","r",stdin);
// freopen("output.txt","w",stdout);
#endif
// ios::sync_with_stdio(false);
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",a+i);
for(int i=1;i<=m;i++)
scanf("%d",b+i);
mmin[n+1]=inf;
for(int i=n;i>=1;i--)
{
mmin[i]=min(mmin[i+1],a[i]);
cnt[mmin[i]]++;
}
if(mmin[1]!=b[1])
return 0*puts("0");
LL ans=1;
for(int i=2;i<=m;i++)
ans=ans*cnt[b[i]]%mod;
printf("%lld\n",ans);
return 0;
}