1327G - Letters and Question Marks(AC自動機+狀壓DP)

題目鏈接

題目大意:

kk個字符串t1,t2,...tkt_1,t_2,...t_ktit_i有權值cic_i.令F(T,t)F(T,t)表示字符串TT中包含多少個ttG(T)=i=1kF(T,ti)ciG(T)=\sum_{i=1}^kF(T,t_i)*c_i
現在給出一個字符串SS,SS中有最多14個位置是未知的,你可以在這些位置上填互不相同的字母ana-n,求G(S)G(S)最大可以是多少。
ti1000,S5e4,106ci106\sum |t_i|\le 1000, |S|\le5e4,-10^6\le c_i \le 10^6

解題思路

注意到未知的位置較少,且必須要填互不相同的字母,這提示我們用狀壓DP去寫。
而統計一些模板字符在一個字符串裏面出現的次數和貢獻,可以使用ac自動機求出。在這題中的障礙是那些未知的位置。
注意到ti1000\sum|t_i|\le 1000,AC自動機最多有1000個結點。未知位置最多有14個,所以原本的串SS最多被分成15段已知的固定的串。
我們令nxt[u][i]nxt[u][i]表示ac自動機的結點uu跑一遍SS的第ii段串之後變成了結點nxt[u][i]nxt[u][i]。令sum[u][i]sum[u][i]表示這個過程中得到的貢獻。
我們用dp[u][mask](mask1cnt)dp[u][mask],(假設mask中的1的個數爲cnt)表示:
處理完前cntcnt個未知位置,使用的字符集合爲maskmask,當前位置爲第cnt+1cnt+1段的最後一個字母,在ac自動機上的位置爲結點uu的情況下,得到的G的最大值.
它的轉移如圖表示:
在這裏插入圖片描述
先枚舉當前使用的字符集合mask,然後枚舉上一段的結尾走到了ac自動機的u,根據第cnt個位置填什麼字符來轉移:
轉移的時候有三段貢獻:

  1. 前面的dp值
  2. 從上一段最後一個位置走到第cnt個’?’(填了i)得到的貢獻
  3. 走到cnt+1段的最後一個位置的貢獻
dp[ nxt[ch[u][i]][num] ][mask] =max(dp[ nxt[ch[u][i]][num] ][mask], dp[u][mask^(1<<i)]+cost[ch[u][i]]+sum[ch[u][i]][num]);

ac代碼

#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define fors(i, a, b) for(int i = (a); i < (b); ++i)
using namespace std;
const int maxn = 4e5 + 50;
int ch[maxn][15], fail[maxn];
ll cost[maxn], rt, tot = 0;
void ins(char *s, int val){
    int p = rt;
    while(*s){
        int x = *s - 'a';
        if(!ch[p][x]) {
            ch[p][x] = ++tot;
        }
        p = ch[p][x];
        s++;
    }
    cost[p] += val;
}
queue<int> q;
void get_fail()
{
    while(q.size()) q.pop();
    for(int i = 0; i < 15; ++i)
        if(ch[rt][i]) q.push(ch[rt][i]), fail[ch[rt][i]] = rt;
        else ch[rt][i] = rt;
    while(q.size()){
        int cur = q.front(); q.pop();
        for(int i = 0; i < 15; ++i){
            if(ch[cur][i]) {
                fail[ ch[cur][i] ] = ch[ fail[cur] ][i];
                q.push(ch[cur][i]);
                cost[ch[cur][i]] += cost[ fail[ ch[cur][i] ] ];
            }
            else ch[cur][i] = ch[fail[cur]][i];
        }
    }
}
char t[1050];
void init(){
    tot = 0; rt = ++tot;
    int n; scanf("%d", &n);
    fors(i, 0, n){
        int x;
        scanf("%s%d", t, &x); ins(t, x);
    }
    get_fail();
}
char s[maxn];
int pos[20], cnt = 0;
int nxt[1050][17];
ll sum[1050][17];
ll dp[1050][1<<14];
int cal(int x){int res = 0; while(x) res++, x-=lowbit(x); return res;}
void sol(){
    scanf("%s", s);
    int n = strlen(s);
    pos[cnt++] = -1;
    fors(i, 0, n) if(s[i] == '?') pos[cnt++] = i;
    pos[cnt] = n;
    fors(i, 0, cnt){
        fors(u, 1, tot+1){
            int p = u;
            fors(j, pos[i]+1, pos[i+1]){
                p = ch[p][s[j]-'a'];
                sum[u][i] += cost[p];
            }nxt[u][i] = p;
        }
    }
    memset(dp, 0xcf, sizeof dp);
    dp[nxt[rt][0]][0] = sum[rt][0];
    ll ans = -1e18;
    if(cnt == 1) ans = sum[rt][0];//if there is no "?"
    fors(mask, 1, (1<<14)){
        int num = cal(mask);
        if(num > cnt-1) continue;
        fors(u, 1, tot+1){
            fors(i, 0, 14){
                if(mask>>i&1){
                    dp[ nxt[ch[u][i]][num] ][mask] =
                    max(dp[ nxt[ch[u][i]][num] ][mask], dp[u][mask^(1<<i)]+cost[ch[u][i]]+sum[ch[u][i]][num]);
                    if(num == cnt-1) {
                        ans = max(ans, dp[ nxt[ch[u][i]][num] ][mask]);
                    }
                }
            }
        }
    }
    cout<<ans<<endl;
}
int main()
{
    init();
    sol();
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章