[LOJ2541] 「PKUWC2018」獵人殺(分治+NTT)

題意

  • nn個人,每個人有一個權值wiw_i,每次隨機殺一個人,殺第ii個人的概率是wij[j is alive]\frac{w_i}{\sum_j[j\ is \ alive]},求第一個人最後一個死的概率,對998244353998244353取模。

這題不是很難但是我自己太sb了所以看了很久。

第一個人最後一個死就代表恰好有00個人在第一個人之後死,算這個是個很套路的東西用容斥轉化爲至少,那麼設f(S)f(S)爲至少有SS集合的人在第一個人之後死的概率,那麼ans=(1)Sf(S)ans=\sum(-1)^{|S|}f(S),現在我們轉化成了求f(S)f(S),實際上這個東西我們可以寫出一個式子:
f(S)=w1w1+w[wS]f(S)=\frac{w_1}{w_1+\sum w[w \in S]}這個東西理解起來也不是很難,每次只有選擇SS集合內的人擊殺或者11擊殺的時候纔會影響他們的相對死亡順序,而選中這裏面的人擊殺時必須要擊殺第一個人才滿足SS集合都在11之後死,那麼我們發現f(S)f(S)只與w[wS]\sum_w[w\in S]有關,又因爲題目中的條件有w105\sum w\le10^5,我們可以考慮計算出每個值的容斥係數,可以直接設f[i][j]f[i][j]表示考慮前ii個人,權值和爲jj時的容斥係數,用揹包的方法轉移就是f[i][j]=f[i1][j]f[i1][jwi]f[i][j] = f[i - 1][j] - f[i - 1][j - w_i],這樣子就可以得到5050分的好成績了,但實際上由於數據比較水,在LOJ上可以通過8080分。

我們設出這個東西的生成函數,第iiaixia_ix^i代表當前dpi=aidp_i=a_i,那麼每次轉移就是乘上(1xwi)(1-x^{w_i}),分別代表是否把當前這個人加入集合的決策,那麼這個生成函數就是:
i=2n(1xwi)\prod_{i = 2}^n(1-x^{w_i})
這個東西直接分治,合併兩個區間的信息的時候用NTT計算就好了,這個東西開數組有點麻煩,那麼我們可以類似於線段樹動態開點回收空間的方法一樣,用完以後利用以前的空間,分治出最深的一條鏈深度應該是logn\log n的,那麼我們開logn\log n個數組就好了,複雜度O(nlog2n)O(n\log ^2n)

#include <bits/stdc++.h>

#define x first
#define y second
#define pb push_back
#define mp make_pair
#define inf (0x3f3f3f3f)

using namespace std;

typedef long long ll;
typedef pair<int, int> PII;

template<class T>inline T read(T &_) {
	T __ = getchar(), ___ = 1; _ = 0;
	for (; !isdigit(__); __ = getchar()) if (__ == '-') ___ = -1;
	for (; isdigit(__); __ = getchar()) _ = (_ << 3) + (_ << 1) + (__ ^ 48);
	return _ *= ___;
}

template<class T>inline bool chkmax(T &_, T __) { return _ < __ ? _ = __, 1 : 0; }
template<class T>inline bool chkmin(T &_, T __) { return _ > __ ? _ = __, 1 : 0; }

inline void proStatus() {
	ifstream t("/proc/self/status");
	cerr << string(istreambuf_iterator<char>(t), istreambuf_iterator<char>());
}

const int N = 1 << 18 | 1; 
const int mod = 998244353;

int w[N], f[N], A[N], B[N], rev[N], S[N], tp[33][N], cnt = -1;

inline int add(int x, int y) { return (x += y) < mod ? x : x - mod; }

inline int qpow(int _, int __) {
	int ___ = 1; 
	for (; __; __ >>= 1, _ = 1ll * _ * _ % mod) 
		if (__ & 1) ___ = 1ll * ___ * _ % mod;
	return ___;
}

inline void NTT(int *a, int n, int fh) {
	for (int i = 0; i < n; ++ i) 
		if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int Wn, limit = 2; limit <= n; limit <<= 1) {
		Wn = qpow(fh ^ 1 ? qpow(3, mod - 2) : 3, (mod - 1) / limit);
		for (int W = 1, j = 0; j < n; j += limit, W = 1) 
			for (int i = j; i < j + (limit >> 1); ++ i, W = 1ll * W * Wn % mod) {
				int a1 = a[i], a2 = 1ll * a[i + (limit >> 1)] * W % mod; 
				a[i] = add(a1, a2), a[i + (limit >> 1)] = add(a1, mod - a2);
			}
	}
	if (fh ^ 1) for (int inv = qpow(n, mod - 2), i = 0; i < n; ++ i) 
		a[i] = 1ll * a[i] * inv % mod;
}

inline void calc(int *a, int *b, int *c, int limit) {
	NTT(a, limit, 1), NTT(b, limit, 1);
	for (int i = 0; i < limit; ++ i) 
		c[i] = 1ll * a[i] * b[i] % mod;
	NTT(c, limit, -1);
}

inline void Solve(int l, int r, int *a) {
	if (l == r) return (void) (a[0] = 1, a[w[l]] = mod - 1);
	int mid = (l + r) >> 1, limit = 1, k = 0, a1 = ++ cnt, a2 = ++ cnt; 
	Solve(l, mid, tp[a1]), Solve(mid + 1, r, tp[a2]);
	for (; limit <= S[r] - S[l - 1]; ++ k) limit <<= 1; 
	for (int i = 0; i < limit; ++ i) 
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
	calc(tp[a1], tp[a2], a, limit), cnt -= 2;
	for (int i = 0; i < limit; ++ i) tp[a1][i] = tp[a2][i] = 0;
}

int main() {
#ifdef ylsakioi
	freopen("2541.in", "r", stdin);
	freopen("2541.out", "w", stdout);
#endif

	int n, ans = 0; 

	read(n);
	for (int i = 1; i <= n; ++ i) 
		S[i] = S[i - 1] + read(w[i]);
	Solve(2, n, f);
	for (int i = 0; i <= S[n] - S[1]; ++ i) 
		ans = add(ans, 1ll * f[i] * qpow(w[1] + i, mod - 2) % mod);
	printf("%lld\n", 1ll * ans * w[1] % mod);

	return 0;
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章