ac自动机题集和应用

最近需要使用ac自动机。补了一下算法。
https://www.cnblogs.com/sclbgw7/p/9260756.html
https://www.cnblogs.com/sclbgw7/p/9875671.html
大佬的博客
下面说自己的心得(算法的理解要等我多刷一些题在写。)
1 大佬博客说的 辅助根优化,我没有发现。正常的字典树不都是有一个根么。ac自动机也用0做根,没毛病啊。
2 链表可以写trie图优化。数组也可以写trie图优化。只不过trie图优化是信息学选手经常用的东西,而他们又不怎么用链表,所以比较少。
3last优化我还不会。
4 ac的fail指针和 kmp相似得地方在于减少重复匹配。kmp的重复匹配大家都懂,ac的重复匹配并不是只是说本身的重复匹配,还有就是实现不同模式串的跳转。
两种情况跳
(1)在当前已经失配的情况下,更好的利用已匹配的结果,需要用fail(fail就是指向当前 字典树 中 我们已经匹配到的部分 中的存在的 最长后缀),比方说
key abd ab,我们的模式串是abc,第三层字典树,c!=d,我滴妈,我要更好的利用已匹配的结果,abd中b的fail指针就是指的ab中的b,哈哈,然后就成功了记录 了ab的匹配成功。(事实上这个 东西可能是 递归的。就像next数组其实也是递归的。,所以对应query中的第一个 while)
(2)当前模式串中可能存在 abc bc ,即一个关键串是另一个关键串的后缀。我们模式串是abc,匹配到了结果是2。
模板

#include <cstdio>
#include <cstring>
#include <queue>
#include <iostream>
using namespace std;
/*  指针的trie图优化。
    show 函数的作用:把指针地址打出来画图用
*/
const int maxw = 10010;      //最大单词数
const int maxwl = 61;        //最大单词长度
const int maxl = 1001000;    //最大文本长度
const int sigm_size = 26;    //字符集大小
// 
struct Node {
    int sum;//>0表示以该结点为前缀的单词个数,=0表示不是单词结点,=-1表示已经经过计数
    Node* chld[sigm_size];
    Node* fail;
    Node() {
        sum = 0;
        memset(chld, 0, sizeof(chld));
        fail = 0;
    }
};
struct ac_automaton {
    Node* root;
    queue<Node*>q;
    void init() {
        root = new Node;
    }
    int idx(char c) {
        return c - 'a';
    }
    void insert(char *s) {
        Node* u = root;
        q.push(u);
        for(int i = 0; i < s[i]; i++) {
            int c = idx(s[i]);
            if(u->chld[c] == NULL)
                u->chld[c] = new Node;
                q.push(u->chld[c]);
            u = u->chld[c];
            //cout<<u<<endl;
        }
        u->sum++;//以该串为前缀的单词个数++
    }
    void getfail() {
          queue<Node *>q;
          q.push(root);
          int tim=0;
          while(!q.empty()){
               Node *u=q.front();
               q.pop();
               for(int i=0;i<sigm_size;i++){
                   if(u->chld[i]==NULL){
                      if(u==root)u->chld[i]=root;
                      else u->chld[i]=u->fail->chld[i];
                   }
                   else{
                        if(u==root)u->chld[i]->fail=root;
                      else u->chld[i]->fail=u->fail->chld[i];
                      //cout<<tim<<"hhh"<<u->chld[i]<<"**"<<i<<endl;
                      q.push(u->chld[i]);
                   }

               }
          }
    }
    int query(char *t) {
        int cnt = 0;//文本中存在单词的个数
        Node* u = root;
        for(int i = 0; t[i]; i++) {//yasherhs
            int c = idx(t[i]);
            u = u->chld[c];  
            Node* tmp = u;
            while(tmp != root) {    
                if(tmp->sum > 0) {
                    cnt += tmp->sum;
                    tmp->sum = -1;  
                }
                else                
                    break;
                tmp = tmp->fail;    //往其他子树上找
            }
        }
        return cnt;
    }
};
ac_automaton ac;
char txt[maxl];
void show(){
     /*while(!ac.q.empty()){
              Node *u=ac.q.front();
              ac.q.pop();
              for(int i=0;i<sigm_size;i++){
                  if(1==1){
                       //q.push(u->chld[i]);
                       cout<<u->chld[i]<<"**"<<u<<"**"<<i<<" "<<u->fail<<endl;
                  }

              }
        }*/

}
int main()
{
    int n, m;
    char word[maxwl];
    scanf("%d", &n);
    while(n--) {
        scanf("%d", &m);
        ac.init();
        for(int i = 0; i < m; i++) {
            scanf("%s", word);
            ac.insert(word);
        }
        ac.getfail();
        int sum=0;
        scanf("%s", txt);
        printf("%d\n", ac.query(txt));
    }
    return 0;
}

先码几道题,在详细的讲理解。
1 hdu2896

Input 第一行,一个整数N(1<=N<=500),表示病毒特征码的个数。
接下来N行,每行表示一个病毒特征码,特征码字符串长度在20—200之间。 每个病毒都有一个编号,依此为1—N。
不同编号的病毒特征码不会相同。 在这之后一行,有一个整数M(1<=M<=1000),表示网站数。
接下来M行,每行表示一个网站源码,源码字符串长度在7000—10000之间。 每个网站都有一个编号,依此为1—M。
以上字符串中字符都是ASCII码可见字符(不包括回车)。 Output
依次按如下格式输出按网站编号从小到大输出,带病毒的网站编号和包含病毒编号,每行一个含毒网站信息。 web 网站编号: 病毒编号 病毒编号
… 冒号后有一个空格,病毒编号按从小到大排列,两个病毒编号之间用一个空格隔开,如果一个网站包含病毒,病毒数不会超过3个。
最后一行输出统计信息,如下格式 total: 带病毒网站数 冒号后有一个空格。 Sample Input
3 aaa bbb ccc
2 aaabbbccc bbaacc
Sample Output
web 1: 1 2 3 total: 1
即统计每个字符串是否出现,出现则输出标号。
思路:使用ac自动机可以很方便的进行这种多模运算(如果你觉得用 后缀数组之类的算法也可以用,那是你没有理解ac自动机),但是,这道题是卡内存的。用指针写的ac自动机很轻松的MLE。 而使用数组(学名静态指针)可以很轻松的过。
下面贴代码
数组版本

#include <iostream>
#include<cstdio>
#include <cstring>
#include <queue>
using namespace std;
// 链表的普通版本+ trie图优化版本
// 数组的普通版本+trie图优化版本。
/*  写指针版本的竟然会MLE
    为什么用 指针版本的会MLE而其他的写法不会。
    但是用 静态指针需要 一大块空间。
*/
const int siz=130;
const int maxn=1e5;
int trie[maxn][siz];
bool vis[503];
int val[maxn];//每个点的权值。
int fail[maxn];
int root;
int sz;
int m,n;
char x[10005];
int newnode(){
    memset(trie[sz],-1,sizeof(trie[sz]));
    return sz++;
}
void init(){
     sz=0;
     root=newnode();
     memset(fail,-1,sizeof(fail));
}
void insert(char x[],int bh){
     int u=root;
     for(int i=0;x[i];i++){
        int num=x[i];
        if(trie[u][num]==-1){
            trie[u][num]=newnode();
        }
        u=trie[u][num];
     }
     val[u]=bh;
     return ;
}
void getfail(){
     queue<int>q;
     q.push(root);
     while(!q.empty()){
          int u=q.front();q.pop();
          for(int i=0;i<siz;i++){
              int num=i;
              if(trie[u][num]==-1)continue;
              if(u==root){
                 fail[trie[u][num]]=root;
                 q.push(trie[u][num]);continue;
              }
              int temp=fail[u];
              while(temp!=-1){
                  if(trie[temp][num]!=-1){
                      fail[trie[u][num]]=trie[temp][num];
                      break;
                  }
                  temp=fail[temp];
              }
              if(temp==-1)fail[trie[u][num]]=root;
              q.push(trie[u][num]);
          }
     }
}
int query(char x[],int &ans,int bh){
      int u=root;
      memset(vis,false,sizeof(vis));
      //int ans=0;
      for(int i=0;x[i];i++){
          int num=x[i];
          while(u!=0&&trie[u][num]==-1){
              u=fail[u];
          }
          u=trie[u][num];
          if(u==-1)u=0;
          int temp=u;
          while(temp!=0){
               if(val[temp]>0){
                  //ans+=val[temp];
                   vis[val[temp]]=true;
                   //val[temp]=0;
                   //cout<<val[temp]<<"**"<<endl;
               }
               else break;
               //else break;
               temp=fail[temp];
          }
      }
      bool flag=false;
     for(int i=1;i<=m;i++){
         if(vis[i]&&!flag){
            flag=true;
            printf("web %d: %d",bh,i);
         }
         else if(vis[i]){
            printf(" %d",i);
         }
     }
     //printf("\n");
     if(flag){ printf("\n");
            ans++;}
            return 0;
}
int main()
{   //char x[500];
    scanf("%d",&m);
    init();
    for(int i=1;i<=m;i++){
        scanf("%s",x);
        insert(x,i);
    }
    getfail();
    scanf("%d",&n);
    int ans=0;
    for(int i=1;i<=n;i++){
       scanf("%s",x);
       query(x,ans,i);
    }
    printf("total: %d\n",ans);

    return 0;
}

2 指针的MLE版本

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <queue>
#include <cstring>
using namespace std;
const int maxn=150;
char x[10006];
bool vis[520];
int m;
struct node{
   node * chld[maxn];
   node *fail;
   int sum;
   node(){
      memset(chld,NULL,sizeof(chld));
      fail=NULL;
      sum=0;
   }
};
node *root;
void init(){
    root=new node;return ;
}
void insert(char x[],int bh){
    node *temp=root;
    for(int i=0;x[i];i++){
        int num=x[i];
        if(temp->chld[num]==NULL){
            temp->chld[num]=new node;
        }
        temp=temp->chld[num];
    }
    temp->sum=bh;
    return ;
}
void getfail(){
    queue<node *>q;
     q.push(root);
     while(!q.empty()){
           node *u=q.front();q.pop();
           for(int i=0;i<maxn;i++){
               if(u->chld[i]==NULL)continue;
               if(u==root){
                   u->chld[i]->fail=root;q.push(u->chld[i]);continue;
                   //continue;
                   // 指向根节点。
               }
               node *temp=u->fail;
               while(temp!=NULL){
                   if(temp->chld[i]!=NULL){
                       u->chld[i]->fail=temp->chld[i];
                       break;
                   }
                   temp=temp->fail;
               }
               if(temp==NULL)u->chld[i]->fail=root;
               q.push(u->chld[i]);
           }
     }
     return ;
}
int query(int bh,char x[],int &ans){
     node *u=root;
     memset(vis,false,sizeof(vis));
     //cout<<"??????"<<endl;
     //int ans=0;
     for(int i=0;x[i];i++){
         int num=x[i];
         //cout<<i<<endl;
         while(u!=root&&u->chld[num]==NULL){
              u=u->fail;
         }
         u=u->chld[num];
         if(u==NULL)u=root;
         node *temp=u;
         while(temp!=root){
              if(temp->sum>0){
                 //ans+=temp->sum;
                 //temp->sum=0;
                 vis[temp->sum]=true;
                // temp->sum=0;
              }
              else break;
              temp=temp->fail;
         }
     }
     //puts("yes");
     bool flag=false;
     for(int i=1;i<=m;i++){
         if(vis[i]&&!flag){
            flag=true;
            printf("web %d: %d",bh,i);
         }
         else if(vis[i]){
            printf(" %d",i);
         }
     }
     //printf("\n");
     if(flag){ printf("\n");
            ans++;}
            return 0;
     //return ans;
}
int main()
{
    //cout << "Hello world!" << endl;

    scanf("%d",&m);
    init();
    for(int i=1;i<=m;i++){
        scanf("%s",x);
        insert(x,i);
    }
    getfail();
    int n;
    scanf("%d",&n);
    int ttt=0;
    //cout<<'a'<<endl;
    for(int i=1;i<=n;i++){
      scanf("%s",x);
       query(i,x,ttt);
     // cout<<"??"<<endl;
    }
    printf("total: %d\n",ttt);
    return 0;
}

hdu3065
和上一题一样。不过注意包含的情况,我的思路是每次前进一步时,都要用fail走一遭,当然用trie图优化更好了。。

#include <iostream>
#include<cstdio>
#include <cstring>
#include <queue>
using namespace std;
// 链表的普通版本+ trie图优化版本
// 数组的普通版本+trie图优化版本。
/*  写指针版本的竟然会MLE
    为什么用 指针版本的会MLE而其他的写法不会。
    但是用 静态指针需要 一大块空间。
*/
const int siz=130;
const int maxn=5e5;
int trie[maxn][siz];
int vis[1003];
int val[maxn];//每个点的权值。
int fail[maxn];
int root;
int sz;
int m,n;
char x[1002][60];
 char s[2000006];
int newnode(){
    memset(trie[sz],-1,sizeof(trie[sz]));
    return sz++;
}
void init(){
     sz=0;
     root=newnode();
     memset(fail,-1,sizeof(fail));
}
void insert(char x[],int bh){
     int u=root;
     for(int i=0;x[i];i++){
        int num=x[i];
        if(trie[u][num]==-1){
            trie[u][num]=newnode();
        }
        u=trie[u][num];
     }
     val[u]=bh;
     return ;
}
void getfail(){
     queue<int>q;
     q.push(root);
     while(!q.empty()){
          int u=q.front();q.pop();
          //cout<<u<<endl;
          for(int i=0;i<siz;i++){
              int num=i;

              if(trie[u][num]!=-1){
              if(u==root){
                 fail[trie[u][num]]=root;
                 //q.push(trie[u][num]);
              }
              else{
              int temp=fail[u];
              while(temp!=-1){
                  if(trie[temp][num]!=-1){
                      fail[trie[u][num]]=trie[temp][num];
                      break;
                  }
                  temp=fail[temp];
                  //cout<<"**i know"<<temp<<"**"<<num<<endl;
              }
              if(temp==-1)fail[trie[u][num]]=root;
              }
              q.push(trie[u][num]);

              }
          }
     }
}
int query(char xx[]){
      int u=root;
      memset(vis,0,sizeof(vis));
      //int ans=0;
      for(int i=0;xx[i];i++){
          int num=xx[i];
          //cout<<i<<endl;
          while(u!=0&&trie[u][num]==-1){
              u=fail[u];
          }
          u=trie[u][num];
          if(u==-1)u=0;
          int temp=u;
          while(temp!=0){
               //cout<<temp<<"**"<<i<<endl;
               if(val[temp]>0){
                  //ans+=val[temp];
                   vis[val[temp]]++;
                   //cout<<val[temp]<<endl;
                   //val[temp]=0;
                   //cout<<val[temp]<<"**"<<endl;
               }

               //else break;
               temp=fail[temp];
          }
      }
      bool flag=false;
     for(int i=1;i<=m;i++){
         if(vis[i]){
            flag=true;
            printf("%s: %d\n",x[i],vis[i]);
         }
     }
     //printf("\n");

            return 0;
}
int main()
{
    while(~scanf("%d",&m)){
    init();
    memset(val,0,sizeof(val));
    for(int i=1;i<=m;i++){
        scanf("%s",x[i]);
        insert(x[i],i);
    }
    //cout<<trie[0][100]<<endl;
    getfail();
    //for()
     //cout<<"??"<<endl;
      //int x=trie[root][(int)('a')];
      //int y=trie[x][(int)('b')];
      //cout<<fail[y]<<endl;
      //cout<<trie[root][(int)('b')]<<endl;
      //cout<<val[fail[y]]<<endl;
       scanf("%s",s);


       query(s);
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章