loj#2325. 「清华集训 2017」小 Y 和恐怖的奴隶主 (矩阵快速幂优化概率dp)

吐槽请无视
哇塞我终于开始更博客了!感不感动!兴不兴奋!%¥#%$#@*&....
emm事实上是因为csdn的LaTeX终于修复好了。。

ps.

之后的题解可能都会相对简略。
并且养成标题上加算法的好习惯,,

题面在这里

题意:

维护一个集合,初始有两个数 {+,m}
进行 n 次操作,每次随机选一个数,把它减一;如果结果为 0 ,把它删掉;否则,如果集合大小不超过 k ,则添一个 m
最后问那个 + 期望被减了多少。
询问 n 的次数 T1000 (极限数据为 500 ),n1018,m3,k8

做法:

定义 fi,a,b,c 表示第 i 轮,血量为1/2/3的分别剩下a/b/c 个奴隶主,此时的期望次数;将所有合法 a,b,c 的状态找出来发现最多165种,于是重新定义, fi,S 表示第 i 轮,状态为 S 的期望次数,fi,S=1num+ssfi+1,S 。用矩阵快速幂优化这个dp。
然后每个询问重新计算很慢,于是先 logn 预处理,复杂度 O(1653logn+T1652logn)
加一些卡常优化,,诸如开个大模数减少模的次数。

代码:

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 10;
const int M = 170;
const int mod = 998244353;
const ll lim = (0x7fffffffffffffffll/mod-mod)*mod;
int m, K, tot, id[N][N][N];
ll n;
ll inv[N], tmp[M], ans[M];
inline ll ksm(ll x, ll p) {
    ll ret = 1;
    for(; p; p >>= 1, x = x*x%mod) if(p&1) ret = ret*x%mod;
    return ret;
}
struct matrix {
    ll s[M][M];
    matrix() { memset(s, 0, sizeof s); }
} a[65];
inline matrix sqr(const matrix &x) {
    matrix ret;
    for(int i = 1; i <= tot+1; i ++)
        for(int j = 1; j <= tot+1; j ++) {
            for(int k = 1; k <= tot+1; k ++) {
                ret.s[i][j] += x.s[i][k]*x.s[k][j];
                if(ret.s[i][j] >= lim) ret.s[i][j] -= lim;
            }
            ret.s[i][j] %= mod;
        }
    return ret;
}
inline void mul(const matrix &x) {
    memset(tmp, 0, sizeof tmp);
    for(int i = 1; i <= tot+1; i ++) {
        for(int j = 1; j <= tot+1; j ++) {
            tmp[i] += ans[j]*x.s[j][i];
            if(tmp[i] >= lim) tmp[i] -= lim;
        }
        tmp[i] %= mod;
    }
    for(int i = 1; i <= tot+1; i ++) ans[i] = tmp[i];
}
int main() {
    int test;
    scanf("%d%d%d", &test, &m, &K);
    for(int i = 0; i <= K; i ++)
        for(int j = 0; j <= ((m>1)?K-i:0); j ++)
            for(int k = 0; k <= ((m>2)?K-i-j:0); k ++) id[i][j][k] = ++ tot;
    for(int i = 0; i <= K+1; i ++) inv[i] = ksm(i, mod-2);
    for(int i = 0; i <= K; i ++)
        for(int j = 0; j <= ((m>1)?K-i:0); j ++)
            for(int k = 0; k <= ((m>2)?K-i-j:0); k ++) {
                int now = id[i][j][k], nk = (i+j+k)<K; ll iv = inv[i+j+k+1];
                if(m >= 1) if(i) a[0].s[now][id[i-1][j][k]] = iv*i%mod;
                if(m >= 2) {
                    if(m == 2) if(j) a[0].s[now][id[i+1][j-1+nk][k]] = iv*j%mod;
                    if(m == 3) if(j) a[0].s[now][id[i+1][j-1][k+nk]] = iv*j%mod;
                }
                if(m >= 3) if(k) a[0].s[now][id[i][j+1][k-1+nk]] = iv*k%mod;
                a[0].s[now][now] = a[0].s[now][tot+1] = iv;
            }
    a[0].s[tot+1][tot+1] = 1;
    for(int i = 1; i <= 63; i ++) a[i] = sqr(a[i-1]);
    while(test --) {
        scanf("%lld", &n);
        memset(ans, 0, sizeof ans);
        if(m == 1) ans[id[1][0][0]] = 1;
        if(m == 2) ans[id[0][1][0]] = 1;
        if(m == 3) ans[id[0][0][1]] = 1;
        for(int i = 0; n; n >>= 1, i ++) if(n&1) mul(a[i]);
        printf("%lld\n", ans[tot+1]);
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章