超簡單(super)
題目描述
有一個n面的骰子,第i面的數是vi,朝上的概率是pi。
教室的最後一排有一個人,不停地拋這個骰子,直到某一面朝上了兩次,就停止拋骰子,但他不知道所有朝上的面的數字的和的期望E是多少。
老班一臉嘲諷:“這不是超簡單嘛。”
輸入
輸入的第一行包含一個正整數n。
輸入的第二行包含n個正整數,表示vi。
輸入的第三行包含n個非負整數,表示模998244353意義下的pi,保證所有pi的和爲1。
n,vi,pi的含義見問題描述。
輸出
輸出一行一個非負整數E表示模998244353意義下的E。
樣例輸入
<span style="color:#333333"><span style="color:#333333">【樣例輸入】
2
1 2
332748118 665496236
</span></span>
樣例輸出
<span style="color:#333333"><span style="color:#333333">【樣例輸出】
961272344
</span></span>
提示
【樣例說明】
骰子共有2個面。
第一面的數爲1,朝上的概率爲1/3;
第二面的數爲2,朝上的概率爲2/3。
所有情況列舉如下:
第1次朝上的面 |
第2次朝上的面 |
第3次朝上的面 |
朝上的面的和 |
概率 |
1 |
1 |
/ |
2 |
1/9 |
1 |
2 |
1 |
4 |
2/27 |
1 |
2 |
2 |
5 |
4/27 |
2 |
1 |
1 |
4 |
2/27 |
2 |
1 |
2 |
5 |
4/27 |
2 |
2 |
/ |
4 |
4/9 |
所以E=2*1/9+4*2/27+5*4/27+4*2/27+5*4/27+4*4/9=110/27。
【子任務】
測試點 |
n |
vi,pi |
1~4 |
≤8 |
<998244353 |
5~8 |
≤50 |
|
9~12 |
≤100 |
|
13~20 |
≤500 |
solution
期望dp
令f[i][j]表示前i張牌選j張得期望
若f[i][j]=PS,新加入i點,那麼新的期望爲P*pi*(S+vi)
展開得到PS*pi+P*pi*vi
於是我們還需維護期望的和g[i][j]=P轉移式有了
f[i][j]=f[i-1][j]+f[i-1][j-1]*pi+g[i-1][j-1]*vi
我們可以枚舉哪一位爲出現兩次的
效率O(n^3)
jyc神犇有優化
因爲這個dp與數的順序無關(不同順序丟進去出來的是一個結果)
我們可以把最後一維當成我要禁掉的
倒推出n-1維
效率O(n^2)
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 505
#define ll long long
#define mod 998244353
using namespace std;
int n;
ll v[maxn],p[maxn],f[maxn][maxn],g[maxn][maxn],ans,h[maxn];
int main(){
cin>>n;
for(int i=1;i<=n;i++)scanf("%lld",&v[i]);
for(int i=1;i<=n;i++)scanf("%lld",&p[i]);
h[0]=1;
for(int i=1;i<=n;i++)h[i]=(h[i-1]*i)%mod;
for(int i=0;i<=n;i++)g[i][0]=1;
for(int i=1;i<=n;i++)
for(int j=1;j<=i;j++){
f[i][j]=f[i-1][j]+(f[i-1][j-1]*p[i])%mod+((g[i-1][j-1]*p[i])%mod*v[i])%mod;
f[i][j]%=mod;
g[i][j]=g[i-1][j]%mod+g[i-1][j-1]*p[i]%mod;
}
for(int b=1;b<=n;b++){
int vb=v[b],pb=p[b];
for(int j=1;j<=n;j++){
f[n-1][j]=f[n][j]-(f[n-1][j-1]*pb)%mod-((g[n-1][j-1]*pb)%mod*vb)%mod;
f[n][j]%=mod;
g[n-1][j]=g[n][j]-g[n-1][j-1]*pb%mod;
}
n--;
for(int i=0;i<=n;i++){
ll tmp=(f[n][i]*pb)%mod*pb;tmp%=mod;
tmp=tmp+g[n][i]*pb%mod*pb%mod*2*vb%mod;
ans=ans+(tmp*h[i+1])%mod;ans%=mod;
}
n++;
}
ans=(ans%mod+mod)%mod;
cout<<ans<<endl;
return 0;
}