題意:n個人共有m個餅乾,每輪隨機選一個餅乾隨機給一個另外的人,所有餅乾都在一個人手裏時遊戲結束,求期望進行次數。模998244353。
n≤105,m≤3×105
首先肯定是每個人作爲最終獲得所有餅乾的人分別考慮。
但是如果只考慮這個人的話,無法確定遊戲進行過程中是否已經在其他人那裏結束了。
所以乾脆改下規則:設當前考慮的人爲x,規定只有x收集完所有餅乾後遊戲才結束。也就是有人收集完了所有餅乾後,如果他不是x,遊戲繼續進行;否則遊戲立即結束。
爲了後面講清楚,這裏給一個不正式的嚴謹定義:設R,R1∼Rn表示遊戲遵循的規則,其中R表示任何一個人收集了所有餅乾後遊戲結束,即原來的規則。Rx(x∈[1,n])表示只有x收集完遊戲才結束的新規則。
設在Rx下游戲期望進行的次數爲Ex′。即:設f′(x,i)表示遊戲進行了i步後結束且x獲得了所有餅乾的概率,Ex′=∑i=0∞i⋅f′(x,i)
注:這個f′(x,i)和後面的f(x,i)只是爲了方便理解定義,對推導沒有影響。
我們考慮尋找新的規則和原來的規則的聯繫
在R下,設Px表示遊戲結束時所有餅乾在x手上的概率,Ex表示所有餅乾在x手上結束的所有情況 的 概率乘以時間 之和。(注意不是期望,概率的分母包括了在其他人那裏結束的情況)
即:設f(x,i)表示遊戲進行i步後結束,x獲得所有餅乾的概率,Ex=∑i=0∞i⋅f(x,i),Px=∑i=0∞f(x,i)
有∑i=1nPi=1,∑i=1nEi=ans,ans爲題目所求
設C表示在Rj下,現在所有餅乾在i手上且i=j,遊戲期望還要進行多少步。顯然這是個與i,j無關的常數。
考慮用Ex表示出Ex′
爲了方便,我們稱一個狀態爲i類關鍵點,當且僅當這個狀態的所有餅乾都在i手上。
考慮在Rx下的一場遊戲,如果它只有結束狀態這一個關鍵點,期望步數爲Ex
否則我們枚舉第一個關鍵點的類別i,顯然i=x,不然遊戲會提前結束。然後從這個關鍵點開始就是 從“所有點都在i手上”這個狀態開始的Rx遊戲,期望步數爲C
所以:
Ex′=Ex+i=1∑n[i=x](Ei+PiC)
如果你想不通爲什麼C要乘上Pi:
前面說過,Ex=∑i=0∞if(x,i)
換句話說,對於每一個i,都有f(x,i)的可能性在i步後 所有餅乾都在x手上,在這裏就是到達枚舉的第一個關鍵點。
在這之後還需要C步來到最終結束的狀態,即∑i=0∞(i+C)f(x,i)
而Px=∑i=0∞f(x,i),所以加上PiC就可以了
拆一下:
Ex′=Ex+i=1∑n[i=x](Ei+PiC)
Ex′=Ex+i=1∑n[i=x]Ei+i=1∑n[i=x]PiC
Ex′=i=1∑nEx+Ci=1∑n[i=x]Pi
Ex′=ans+C(1−px)
對x=1∼n求和
i=1∑nEi′=n⋅ans+C(n−1)
只要求出Ex′和C的值,問題就解決了!
而C是嚴格包含於Ex′的,所以我們的目標是解決新規則下的問題。
注意到在Rx下,我們並不關心每個餅乾具體在誰手上,我們只關心它在不在x手上
所以可以設f(i)表示當前x手上有i個餅乾時期望進行次數。
f(i)=⎩⎪⎨⎪⎧1+n−11f(i+1)+n−1n−2f(i)1+mif(i−1)+mm−i(n−11f(i+1)+n−1n−2f(i))0x=00<x<mx=m
解出這個方程就好了
暴力硬推是一種方法,這裏有一個很優美的思路:
考慮中間的式子,記爲
f(i)=Af(i−1)+Bf(i)+Cf(i+1)+1
Af(i−1)+(B−1)f(i)+Cf(i+1)+1=0
注意到A+B+C=1,設g(i)=f(i)−f(i+1)
A⋅g(i−1)+(A+B−1)g(i)+1=0
g(i)=CA⋅g(i−1)+1
這樣推一次就可以把g算出來,求個後綴和就得到了f
然後Ex′=f(ai),C=f(0)
複雜度O(n)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#define MAXN 300005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int qpow(int a,int p)
{
int ans=1;
while (p)
{
if (p&1) ans=(ll)ans*a%MOD;
a=(ll)a*a%MOD;p>>=1;
}
return ans;
}
#define inv(x) qpow(x,MOD-2)
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
inline int dec(const int& x,const int& y){return x<y? x-y+MOD: x-y;}
int a[MAXN],f[MAXN];
int main()
{
int n=read(),m=0;
for (int i=1;i<=n;i++) m+=(a[i]=read());
f[0]=n-1;
for (int i=1;i<m;i++) f[i]=((ll)i*inv(m)%MOD*f[i-1]%MOD+1)*m%MOD*(n-1)%MOD*inv(m-i)%MOD;
for (int i=m;i>=0;i--) f[i]=add(f[i],f[i+1]);
int sum=0;
for (int i=1;i<=n;i++) sum=add(sum,f[a[i]]);
sum=dec(sum,f[0]*(n-1ll)%MOD);
sum=(ll)sum*inv(n)%MOD;
cout<<sum;
return 0;
}