【AC自動機】hdu2222 hdu2896 hdu3065 zoj3430 poj2778 hdu2243

AC自動機用於多個模式串與多個母串的匹配。
第一步:根據模式串建立字典樹

int len=strlen(w), r=root;
    for(int i=0;i<len;++i)
    {
    if(tree[r].ch[w[i]])r=tree[r].ch[w[i]];
        else r=tree[r].ch[w[i]]=++cnt;
    }
    ++tree[r].cnt;//cnt爲在該節點結束的模式串的數量

第二步:計算每一個節點的fail指針(與kmp中next數組相似)。首先找到u的父親的fail指針v。若v對應的兒子不爲空,則u的fail指針指向v。否則訪問v的fail指針。

void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<128;++i)//根據模式串的字符大小確定i的範圍
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)tree[tmp].fail=tree[tree[p].fail].ch[i];
                q[++tail]=tmp;
            }
            else tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

第三步:在字典樹中查找

void query()
{
    int len=strlen(w), ans=0, p=0, tmp;
    for(int i=0;i<len;++i)
    {
        p=tree[p].ch[w[i]-'a'];
        tmp=p;
        while(tree[tmp].pos)
        {
            ans+=tree[tmp].pos;
            tree[tmp].pos=0;
            tmp=tree[tmp].fail;
        }
    }
    printf("%d\n",ans);
}

模板題:
hdu2222
題目大意:問在一個母串中有多少個模式串出現
直接上模板

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXN 500005
#define MAXM 1000005
using namespace std;

int n;
char w[MAXM];

struct node
{
    int pos, ch[26], fail;
    inline void init()
    {
        fail=pos=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN];
int cnt, root;

void add()
{
    int len=strlen(w), r=root;
    for(int i=0;i<len;++i)
    {
        w[i]-='a';
        if(tree[r].ch[w[i]])r=tree[r].ch[w[i]];
        else r=tree[r].ch[w[i]]=++cnt;
    }
    ++tree[r].pos;
}

int head, tail, q[MAXN];
void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<26;++i)
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)tree[tmp].fail=tree[tree[p].fail].ch[i];
                q[++tail]=tmp;
            }
            else tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

void query()
{
    int len=strlen(w), ans=0, p=0, tmp;
    for(int i=0;i<len;++i)
    {
        p=tree[p].ch[w[i]-'a'];
        tmp=p;
        while(tree[tmp].pos)
        {
            ans+=tree[tmp].pos;
            tree[tmp].pos=0;
            tmp=tree[tmp].fail;
        }
    }
    printf("%d\n",ans);
}

int main()
{
    int cas;
    scanf("%d",&cas);
    while(cas--)
    {
        cnt=root=0;
        scanf("%d",&n);
        for(int i=0;i<n;++i)
        {
            scanf("%s",w);
            add();
        }
        bfs();
        scanf("%s",w);
        query();
        for(int i=0;i<=cnt;++i)
            tree[i].init();
    }
    return 0;
}

hdu2896
一定要注意模式串字符的範圍。蒟蒻在此處RE了幾次才發現…
提供指針版
不過還是數組版的好調試一些 233

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 100005
#define MAXM 10005
using namespace std;

int n, m;
bool vis[MAXN];
struct node
{
    int pos;
    node *ch[128], *fail;
    inline void init()
    {
        fail=0, pos=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN], *cnt, *root;

char w[MAXM];
void add(int j)
{
    int len=strlen(w), v;
    node *r=root;
    for(int i=0;i<len;++i)
    {
        v=w[i];
        if(r->ch[v])r=r->ch[v];
        else r=r->ch[v]=++cnt;
    }
    r->pos=j;
    vis[j]=1;
}

node *q[MAXN];
int head, tail;
void bfs()
{
    head=tail=0;
    q[++tail]=root;
    node *p, *son, *tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<128;++i)
        {
            son=p->ch[i];
            if(son)
            {
                if(p==root)son->fail=p;
                else
                {
                    tmp=p->fail;
                    while(tmp)
                    {
                        if(tmp->ch[i])
                        {
                            son->fail=tmp->ch[i];
                            break;
                        }
                        tmp=tmp->fail;
                    }
                    if(!tmp)son->fail=root;
                }
                q[++tail]=son;
            }
        }
    }
}

int num[MAXM], tmp, tot;
void query(int j)
{
    int len=strlen(w), v;
    node *p=root, *temp;
    for(int i=0;i<len;++i)
    {
        v=w[i];
        while(!p->ch[v]&&p!=root)
            p=p->fail;
        p=p->ch[v];
        if(!p)p=root;
        temp=p;
        while(vis[temp->pos])
        {
            num[++tmp]=temp->pos;
            vis[temp->pos]=0;
            temp=temp->fail;
        }
    }
    if(tmp)
    {
        sort(num+1,num+tmp+1);
        printf("web %d:",j);
        for(int i=1;i<=tmp;++i)
        {
            printf(" %d",num[i]);
            vis[num[i]]=1;
        }
        ++tot, tmp=0;
        puts("");
    }
}

int main()
{
    while(~scanf("%d",&n))
    {
        root=cnt=tree, tot=0;
        for(int i=1;i<=n;++i)
        {
            scanf("%s",w);
            add(i);
        }
        bfs();
        scanf("%d",&m);
        for(int i=1;i<=m;++i)
        {
            scanf("%s",w);
            query(i);
        }
        printf("total: %d\n",tot);
        for(node *p=tree;p<=cnt;++p)
            p->init();
    }
    return 0;
}

hdu3065
這裏就涉及到了計數的問題。
不進行標記。在之前的模板中都加了一個優化,即訪問了一個節點就打上標記。

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXM 2000005
#define MAXN 50005
using namespace std;

int n, ans[1005];
char w[1005][55], s[MAXM];

struct node
{
    node *ch[26], *fail;
    int pos;
    void init()
    {
        fail=0, pos=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN], *root, *cnt;

void add(node *r,int j)
{
    int len=strlen(w[j]), v;
    for(int i=0;i<len;++i)
    {
        v=w[j][i]-'A';
        if(r->ch[v])r=r->ch[v];
        else r=r->ch[v]=++cnt;
    }
    r->pos=j;
}

int head, tail;
node *q[MAXN];
void bfs()
{
    head=tail=0;
    q[++tail]=root;
    node *p, *son, *tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<26;++i)
        {
            son=p->ch[i];
            if(son)
            {
                if(p==root)son->fail=p;
                else
                {
                    tmp=p->fail;
                    while(tmp)
                    {
                        if(tmp->ch[i])
                        {
                            son->fail=tmp->ch[i];
                            break;
                        }
                        tmp=tmp->fail;
                    }
                    if(!tmp)son->fail=root;
                }
                q[++tail]=son;
            }
        }
    }
}

void query()
{
    int len=strlen(s), v;
    node *p=root, *temp;
    for(int i=0;i<len;++i)
    {
        if(s[i]<'A'||s[i]>'Z')
        {
            p=root;
            continue;
        }
        v=s[i]-'A';
        while(p)
        {
            if(p->ch[v])
            {
                p=p->ch[v];
                break;
            }
            p=p->fail;
        }
        if(!p)p=root;
        else
        {
            temp=p;
            while(temp)
            {
                if(temp->pos)++ans[temp->pos];
                temp=temp->fail;
            }
        }
    }
}

int main()
{
    while(~scanf("%d",&n))
    {
        root=cnt=tree;
        for(int i=1;i<=n;++i)
        {
            scanf("%s",w[i]);
            add(root,i);
        }
        bfs();
        scanf("%s",s);
        query();
        for(int i=1;i<=n;++i)
            if(ans[i])
            {
                printf("%s: %d\n",w[i],ans[i]);
                ans[i]=0;
            }
        for(node *p=tree;p<=cnt;++p)
            p->init();
    }
    return 0;

zoj3430
題目大意:給你一個加密規則:將所有字符寫成二進制並串聯起來。然後每6位數組成一個新的二進制數,再轉化爲十進制數,根據密碼錶翻譯成字符。若len%3==1,就再加上=。若len%3==2,就加上==。
現在已知一些被加密後的模式串和一些加密後的母串。求每一個母串中出現了多少個模式串。

這道題巧妙地運用位運算可以很方便的還原字符串。要注意雖然密碼串是從‘0’到‘z’,但原串是0到255。而且是多組數據,一定要清零。被坑慘了…

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXN 50050
using namespace std;

char w[MAXN];
bool vis[MAXN];
int n, temp[MAXN], cnt;

struct node
{
    int ch[256], fail, cnt;
    void init()
    {
        cnt=fail=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN];

void add(int r)
{
    for(int i=1;i<=temp[0];++i)
    {
        if(tree[r].ch[temp[i]])r=tree[r].ch[temp[i]];
        else r=tree[r].ch[temp[i]]=++cnt;
    }
    tree[r].cnt=1;
}

int key[256];
void table()
{
    for(int i=0;i<26;++i)key[i+'A']=i;
    for(int i=0;i<26;++i)key[i+'a']=i+26;
    for(int i=0;i<10;++i)key[i+'0']=i+52;
    key['+']=62, key['/']=63;
}

void change(char s[])
{
    temp[0]=0;
    int len, x=0;
    for(len=strlen(s);s[len-1]=='=';--len);
    for(int i=0, tmp=0;i<len;++i)
    {
        x=(x<<6)|key[s[i]], tmp+=6;
        if(tmp>=8)
        {
            temp[++temp[0]]=(x>>(tmp-8))&255;
            tmp-=8;
        }
    }
}

int head, tail, q[MAXN];
void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<256;++i)
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)tree[tmp].fail=tree[tree[p].fail].ch[i];
                q[++tail]=tmp;
            }
            else tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

void query(int w[])
{
    int ans=0, p=0, tmp;
    for(int i=1;i<=w[0];++i)
    {
        p=tree[p].ch[w[i]];
        tmp=p;
        while(vis[tmp])
        {
            ans+=tree[tmp].cnt;
            vis[tmp]=0;
            tmp=tree[tmp].fail;
        }
    }
    printf("%d\n",ans);
}

int main()
{
    table();
    while(~scanf("%d",&n))
    {
        cnt=0;
        for(int i=1;i<=n;++i)
        {
            scanf("%s",w);
            change(w);
            add(0);
        }
        bfs();
        scanf("%d",&n);
        while(n--)
        {
            memset(vis,1,sizeof vis);
            scanf("%s",w);
            change(w);
            query(temp);
        }
        for(int i=0;i<=cnt;++i)
            tree[i].init();
        puts("");
    }
    return 0;
}

poj2778
題目大意:有n 個病毒的DNA序,求長度爲l 的DNA序中不含病毒的個數。

不得不承認蒟蒻是看了題解的…
首先對這n 個病毒DNA序構建AC自動機。自動集中的節點都是由有向邊連接而成的,那麼將那些打了標記的節點刪除,剩下的就是一個圖。問題就轉化爲從起點0開始,走N步有多少種方案。
這樣就是一個矩陣加速的問題了。

#include <iostream>
#include <cstdio>
#include <cstring>
#define MAXN 110
#define LL long long int
#define mod 100000
using namespace std;

struct mat
{
    LL num[MAXN][MAXN], n;
    void init()
    {
        memset(num,0,sizeof num);
        n=0;
    }
    mat operator * (const mat &a)const
    {
        mat ans;
        ans.init();
        ans.n=n;
        for(int i=0;i<=n;++i)
            for(int j=0;j<=n;++j)
                for(int k=0;k<=n;++k)
                    ans.num[i][k]=(ans.num[i][k]+num[i][j]*a.num[j][k])%mod;
        return ans;
    }
}ans;

mat power(mat a,int pos)
{
    mat ans=a;
    while(pos)
    {
        if(pos&1)ans=ans*a;
        a=a*a;
        pos>>=1;
    }
    return ans;
}

inline int getid(char a)
{
    if(a=='A')return 0;
    if(a=='C')return 1;
    if(a=='G')return 2;
    return 3;
}

struct node
{
    int ch[5], fail, cnt;
    void init()
    {
        fail=cnt=0;
        memset(ch,0,sizeof ch);
    }
}tree[MAXN];

int cnt, root;
char w[MAXN];
void add()
{
    int len=strlen(w), r=root, v;
    for(int i=0;i<len;++i)
    {
        v=getid(w[i]);
        if(tree[r].ch[v])r=tree[r].ch[v];
        else r=tree[r].ch[v]=++cnt;
    }
    ++tree[r].cnt;
}

int q[MAXN], head, tail;
void bfs()
{
    head=tail=0;
    q[++tail]=0;
    int p, tmp;
    while(head<tail)
    {
        p=q[++head];
        for(int i=0;i<4;++i)
        {
            if(tree[p].ch[i])
            {
                tmp=tree[p].ch[i];
                if(p)
                {
                    tree[tmp].fail=tree[tree[p].fail].ch[i];
                    tree[tmp].cnt+=tree[tree[tree[p].fail].ch[i]].cnt;
                }
                q[++tail]=tmp;
            }
            else
                tree[p].ch[i]=tree[tree[p].fail].ch[i];
        }
    }
}

void build()
{
    ans.init();
    for(int i=0;i<=cnt;++i)
    {
        for(int j=0;j<4;++j)
        {
            if(tree[tree[i].ch[j]].cnt)continue;
            ++ans.num[i][tree[i].ch[j]];
        }
    }
    ans.n=cnt;
}

int n, m;
LL out;
int main()
{
    while(~scanf("%d%d",&n,&m))
    {
        root=cnt=0;
        for(int i=0;i<n;++i)
        {
            scanf("%s",w);
            add();
        }
        bfs();
        build();
        ans=power(ans,m-1);
        out=0;
        for(int i=0;i<=cnt;++i)out=(out+ans.num[0][i])%mod;
        printf("%d\n",out);
        for(int i=0;i<=cnt;++i)tree[i].init();
    }
    return 0;
}

hdu2243
這道題就是上面那道的加強版。
蒟蒻目前TLE中,期待持續更新…

發佈了75 篇原創文章 · 獲贊 4 · 訪問量 4萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章