题意:n个人共有m个饼干,每轮随机选一个饼干随机给一个另外的人,所有饼干都在一个人手里时游戏结束,求期望进行次数。模998244353。
n≤105,m≤3×105
首先肯定是每个人作为最终获得所有饼干的人分别考虑。
但是如果只考虑这个人的话,无法确定游戏进行过程中是否已经在其他人那里结束了。
所以干脆改下规则:设当前考虑的人为x,规定只有x收集完所有饼干后游戏才结束。也就是有人收集完了所有饼干后,如果他不是x,游戏继续进行;否则游戏立即结束。
为了后面讲清楚,这里给一个不正式的严谨定义:设R,R1∼Rn表示游戏遵循的规则,其中R表示任何一个人收集了所有饼干后游戏结束,即原来的规则。Rx(x∈[1,n])表示只有x收集完游戏才结束的新规则。
设在Rx下游戏期望进行的次数为Ex′。即:设f′(x,i)表示游戏进行了i步后结束且x获得了所有饼干的概率,Ex′=∑i=0∞i⋅f′(x,i)
注:这个f′(x,i)和后面的f(x,i)只是为了方便理解定义,对推导没有影响。
我们考虑寻找新的规则和原来的规则的联系
在R下,设Px表示游戏结束时所有饼干在x手上的概率,Ex表示所有饼干在x手上结束的所有情况 的 概率乘以时间 之和。(注意不是期望,概率的分母包括了在其他人那里结束的情况)
即:设f(x,i)表示游戏进行i步后结束,x获得所有饼干的概率,Ex=∑i=0∞i⋅f(x,i),Px=∑i=0∞f(x,i)
有∑i=1nPi=1,∑i=1nEi=ans,ans为题目所求
设C表示在Rj下,现在所有饼干在i手上且i=j,游戏期望还要进行多少步。显然这是个与i,j无关的常数。
考虑用Ex表示出Ex′
为了方便,我们称一个状态为i类关键点,当且仅当这个状态的所有饼干都在i手上。
考虑在Rx下的一场游戏,如果它只有结束状态这一个关键点,期望步数为Ex
否则我们枚举第一个关键点的类别i,显然i=x,不然游戏会提前结束。然后从这个关键点开始就是 从“所有点都在i手上”这个状态开始的Rx游戏,期望步数为C
所以:
Ex′=Ex+i=1∑n[i=x](Ei+PiC)
如果你想不通为什么C要乘上Pi:
前面说过,Ex=∑i=0∞if(x,i)
换句话说,对于每一个i,都有f(x,i)的可能性在i步后 所有饼干都在x手上,在这里就是到达枚举的第一个关键点。
在这之后还需要C步来到最终结束的状态,即∑i=0∞(i+C)f(x,i)
而Px=∑i=0∞f(x,i),所以加上PiC就可以了
拆一下:
Ex′=Ex+i=1∑n[i=x](Ei+PiC)
Ex′=Ex+i=1∑n[i=x]Ei+i=1∑n[i=x]PiC
Ex′=i=1∑nEx+Ci=1∑n[i=x]Pi
Ex′=ans+C(1−px)
对x=1∼n求和
i=1∑nEi′=n⋅ans+C(n−1)
只要求出Ex′和C的值,问题就解决了!
而C是严格包含于Ex′的,所以我们的目标是解决新规则下的问题。
注意到在Rx下,我们并不关心每个饼干具体在谁手上,我们只关心它在不在x手上
所以可以设f(i)表示当前x手上有i个饼干时期望进行次数。
f(i)=⎩⎪⎨⎪⎧1+n−11f(i+1)+n−1n−2f(i)1+mif(i−1)+mm−i(n−11f(i+1)+n−1n−2f(i))0x=00<x<mx=m
解出这个方程就好了
暴力硬推是一种方法,这里有一个很优美的思路:
考虑中间的式子,记为
f(i)=Af(i−1)+Bf(i)+Cf(i+1)+1
Af(i−1)+(B−1)f(i)+Cf(i+1)+1=0
注意到A+B+C=1,设g(i)=f(i)−f(i+1)
A⋅g(i−1)+(A+B−1)g(i)+1=0
g(i)=CA⋅g(i−1)+1
这样推一次就可以把g算出来,求个后缀和就得到了f
然后Ex′=f(ai),C=f(0)
复杂度O(n)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#define MAXN 300005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int qpow(int a,int p)
{
int ans=1;
while (p)
{
if (p&1) ans=(ll)ans*a%MOD;
a=(ll)a*a%MOD;p>>=1;
}
return ans;
}
#define inv(x) qpow(x,MOD-2)
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
inline int dec(const int& x,const int& y){return x<y? x-y+MOD: x-y;}
int a[MAXN],f[MAXN];
int main()
{
int n=read(),m=0;
for (int i=1;i<=n;i++) m+=(a[i]=read());
f[0]=n-1;
for (int i=1;i<m;i++) f[i]=((ll)i*inv(m)%MOD*f[i-1]%MOD+1)*m%MOD*(n-1)%MOD*inv(m-i)%MOD;
for (int i=m;i>=0;i--) f[i]=add(f[i],f[i+1]);
int sum=0;
for (int i=1;i<=n;i++) sum=add(sum,f[a[i]]);
sum=dec(sum,f[0]*(n-1ll)%MOD);
sum=(ll)sum*inv(n)%MOD;
cout<<sum;
return 0;
}