ZOJ3494 BCD Code (AC自動機+數位DP)

用AC自動機構造出病毒串的trie圖,然後設狀態dp[i][j]表示長度爲i且位於j節點時的符合要求的數的數量,然後按照普通數位DP做即可

遞推式數位DP統計[1,x]內符合條件的數只需要考慮三種情況:

1,位數比x短的數

2,位數和x一樣,但是某一位比x小的數

3,x本身是否符合條件

0比較特殊,在一般的數位DP中需要特殊處理,一般以特判爲主

AC自動機的話,即trie樹+fail樹,理解了fail邊的構造,則寫出AC自動機就不是什麼困難的事了

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <queue>
using namespace std;
struct AC {
        struct Node {
                Node *ch[2],*fail;
                bool ed,vis;
        }memo[2010],*root;
        int tot;
        void New_node(Node *&o) {
                o = &memo[tot++];
                o->ch[0] = o->ch[1] = NULL;
                o->fail = root;
                o->ed = o->vis = 0;
        }
        void init() {
                tot = 0;
                New_node(root);
        }
        void ins(char *s) {
                Node *p = root;
                for ( ; *s; s ++) {
                        int c = *s-'0';
                        if (p->ch[c]==NULL) New_node(p->ch[c]);
                        p = p->ch[c];
                }
                p->ed = 1;
        }
        void dfs(Node *p) {
                p->vis = true;
                if (p==root) return ;
                if (p->fail->vis==false) dfs(p->fail);
                p->ed |= p->fail->ed;
        }
        void build() {
                queue<Node*> que;
                for (int i = 0; i < 2; i ++)
                        if (root->ch[i]!=NULL) que.push(root->ch[i]);
                while (!que.empty()) {
                        Node *f = que.front(); que.pop();
                        for (int i = 0; i < 2; i ++) {
                                if (f->ch[i]!=NULL) {
                                        Node *p = f->fail;
                                        for ( ; p->ch[i]==NULL && p!=root; p = p->fail);
                                        f->ch[i]->fail = p->ch[i]==NULL ? root : p->ch[i];
                                        que.push(f->ch[i]);
                                }
                        }
                }
                for (int i = 0; i < tot; i ++) if (!memo[i].vis) dfs(memo+i);
        }
        int match(int v,char *s) {
                Node *p = &memo[v];
                if (p->ed) return -1;
                for ( ; *s; s ++) {
                        int c = *s-'0';
                        while (p->ch[c]==NULL && p!=root) p = p->fail;
                        if (p->ch[c]!=NULL) p = p->ch[c];
                        if (p->ed) return -1;
                }
                return p-memo;
        }
}ac;

typedef long long lld;
const int MOD = (int)1e9+9;
char s[222],snum[10][6] = {"0000","0001","0010","0011","0100","0101","0110","0111","1000","1001"};
lld dp[222][2010];
int g[2010][10];
void add(lld &a,lld b) { a += b; if (a>=MOD) a -= MOD; }
void init() {
        for (int i = 0; i < 222; i ++)
                for (int j = 0; j < ac.tot; j ++)
                        dp[i][j] = 0;
        for (int i = 0; i < ac.tot; i ++)
                for (int j = 0; j < 10; j ++)
                        g[i][j] = ac.match(i,snum[j]);
        for (int i = 0; i < ac.tot; i ++)
                dp[1][i] = 1;
        for (int i = 1; i < 222-1; i ++)
                for (int j = 0; j < ac.tot; j ++)
                        for (int k = 0; k < 10; k ++) if (g[j][k]!=-1)
                                add(dp[i+1][j],dp[i][g[j][k]]);
}
lld calc(bool mark) {
        lld ret = 0;
        int p = 0,len = strlen(s+1);
        for (int i = 1; i < len-i+1; i ++) swap(s[i],s[len-i+1]);
        for (int i = len; i >= 1; i --) {
                if (p==-1) break;
                for (int j = s[i]-'0'-1; j > 0; j --) {
                        int v = g[p][j];
                        if (v==-1) continue;
                        add(ret,dp[i][v]);
                }
                if (s[i]-'0'>0 && g[p][0]!=-1 && i!=len) 
                        add(ret,dp[i][g[p][0]]);
                p = g[p][s[i]-'0'];
        }
        add(ret,(p!=-1)*mark);
        for (int i = len-1; i >= 1; i --) 
                for (int j = 1; j < 10; j ++) if (g[0][j]!=-1)
                        add(ret,dp[i][g[0][j]]);
        return ret;
}
int main() {
        int cas;
        scanf("%d",&cas);
        while (cas--) {
                int n;
                scanf("%d",&n);
                ac.init();
                for (int i = 0; i < n; i ++) {
                        scanf("%s",s);
                        ac.ins(s);
                }
                ac.build();
                init();
                lld ans = 0;
                scanf("%s",s+1);
                ans -= calc(0);
                scanf("%s",s+1);
                (((ans += calc(1)) %= MOD) += MOD) %= MOD;
                printf("%lld\n",ans);
        }
        return 0;
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章