ac自動機題集和應用

最近需要使用ac自動機。補了一下算法。
https://www.cnblogs.com/sclbgw7/p/9260756.html
https://www.cnblogs.com/sclbgw7/p/9875671.html
大佬的博客
下面說自己的心得(算法的理解要等我多刷一些題在寫。)
1 大佬博客說的 輔助根優化,我沒有發現。正常的字典樹不都是有一個根麼。ac自動機也用0做根,沒毛病啊。
2 鏈表可以寫trie圖優化。數組也可以寫trie圖優化。只不過trie圖優化是信息學選手經常用的東西,而他們又不怎麼用鏈表,所以比較少。
3last優化我還不會。
4 ac的fail指針和 kmp相似得地方在於減少重複匹配。kmp的重複匹配大家都懂,ac的重複匹配並不是只是說本身的重複匹配,還有就是實現不同模式串的跳轉。
兩種情況跳
(1)在當前已經失配的情況下,更好的利用已匹配的結果,需要用fail(fail就是指向當前 字典樹 中 我們已經匹配到的部分 中的存在的 最長後綴),比方說
key abd ab,我們的模式串是abc,第三層字典樹,c!=d,我滴媽,我要更好的利用已匹配的結果,abd中b的fail指針就是指的ab中的b,哈哈,然後就成功了記錄 了ab的匹配成功。(事實上這個 東西可能是 遞歸的。就像next數組其實也是遞歸的。,所以對應query中的第一個 while)
(2)當前模式串中可能存在 abc bc ,即一個關鍵串是另一個關鍵串的後綴。我們模式串是abc,匹配到了結果是2。
模板

#include <cstdio>
#include <cstring>
#include <queue>
#include <iostream>
using namespace std;
/*  指針的trie圖優化。
    show 函數的作用:把指針地址打出來畫圖用
*/
const int maxw = 10010;      //最大單詞數
const int maxwl = 61;        //最大單詞長度
const int maxl = 1001000;    //最大文本長度
const int sigm_size = 26;    //字符集大小
// 
struct Node {
    int sum;//>0表示以該結點爲前綴的單詞個數,=0表示不是單詞結點,=-1表示已經經過計數
    Node* chld[sigm_size];
    Node* fail;
    Node() {
        sum = 0;
        memset(chld, 0, sizeof(chld));
        fail = 0;
    }
};
struct ac_automaton {
    Node* root;
    queue<Node*>q;
    void init() {
        root = new Node;
    }
    int idx(char c) {
        return c - 'a';
    }
    void insert(char *s) {
        Node* u = root;
        q.push(u);
        for(int i = 0; i < s[i]; i++) {
            int c = idx(s[i]);
            if(u->chld[c] == NULL)
                u->chld[c] = new Node;
                q.push(u->chld[c]);
            u = u->chld[c];
            //cout<<u<<endl;
        }
        u->sum++;//以該串爲前綴的單詞個數++
    }
    void getfail() {
          queue<Node *>q;
          q.push(root);
          int tim=0;
          while(!q.empty()){
               Node *u=q.front();
               q.pop();
               for(int i=0;i<sigm_size;i++){
                   if(u->chld[i]==NULL){
                      if(u==root)u->chld[i]=root;
                      else u->chld[i]=u->fail->chld[i];
                   }
                   else{
                        if(u==root)u->chld[i]->fail=root;
                      else u->chld[i]->fail=u->fail->chld[i];
                      //cout<<tim<<"hhh"<<u->chld[i]<<"**"<<i<<endl;
                      q.push(u->chld[i]);
                   }

               }
          }
    }
    int query(char *t) {
        int cnt = 0;//文本中存在單詞的個數
        Node* u = root;
        for(int i = 0; t[i]; i++) {//yasherhs
            int c = idx(t[i]);
            u = u->chld[c];  
            Node* tmp = u;
            while(tmp != root) {    
                if(tmp->sum > 0) {
                    cnt += tmp->sum;
                    tmp->sum = -1;  
                }
                else                
                    break;
                tmp = tmp->fail;    //往其他子樹上找
            }
        }
        return cnt;
    }
};
ac_automaton ac;
char txt[maxl];
void show(){
     /*while(!ac.q.empty()){
              Node *u=ac.q.front();
              ac.q.pop();
              for(int i=0;i<sigm_size;i++){
                  if(1==1){
                       //q.push(u->chld[i]);
                       cout<<u->chld[i]<<"**"<<u<<"**"<<i<<" "<<u->fail<<endl;
                  }

              }
        }*/

}
int main()
{
    int n, m;
    char word[maxwl];
    scanf("%d", &n);
    while(n--) {
        scanf("%d", &m);
        ac.init();
        for(int i = 0; i < m; i++) {
            scanf("%s", word);
            ac.insert(word);
        }
        ac.getfail();
        int sum=0;
        scanf("%s", txt);
        printf("%d\n", ac.query(txt));
    }
    return 0;
}

先碼幾道題,在詳細的講理解。
1 hdu2896

Input 第一行,一個整數N(1<=N<=500),表示病毒特徵碼的個數。
接下來N行,每行表示一個病毒特徵碼,特徵碼字符串長度在20—200之間。 每個病毒都有一個編號,依此爲1—N。
不同編號的病毒特徵碼不會相同。 在這之後一行,有一個整數M(1<=M<=1000),表示網站數。
接下來M行,每行表示一個網站源碼,源碼字符串長度在7000—10000之間。 每個網站都有一個編號,依此爲1—M。
以上字符串中字符都是ASCII碼可見字符(不包括回車)。 Output
依次按如下格式輸出按網站編號從小到大輸出,帶病毒的網站編號和包含病毒編號,每行一個含毒網站信息。 web 網站編號: 病毒編號 病毒編號
… 冒號後有一個空格,病毒編號按從小到大排列,兩個病毒編號之間用一個空格隔開,如果一個網站包含病毒,病毒數不會超過3個。
最後一行輸出統計信息,如下格式 total: 帶病毒網站數 冒號後有一個空格。 Sample Input
3 aaa bbb ccc
2 aaabbbccc bbaacc
Sample Output
web 1: 1 2 3 total: 1
即統計每個字符串是否出現,出現則輸出標號。
思路:使用ac自動機可以很方便的進行這種多模運算(如果你覺得用 後綴數組之類的算法也可以用,那是你沒有理解ac自動機),但是,這道題是卡內存的。用指針寫的ac自動機很輕鬆的MLE。 而使用數組(學名靜態指針)可以很輕鬆的過。
下面貼代碼
數組版本

#include <iostream>
#include<cstdio>
#include <cstring>
#include <queue>
using namespace std;
// 鏈表的普通版本+ trie圖優化版本
// 數組的普通版本+trie圖優化版本。
/*  寫指針版本的竟然會MLE
    爲什麼用 指針版本的會MLE而其他的寫法不會。
    但是用 靜態指針需要 一大塊空間。
*/
const int siz=130;
const int maxn=1e5;
int trie[maxn][siz];
bool vis[503];
int val[maxn];//每個點的權值。
int fail[maxn];
int root;
int sz;
int m,n;
char x[10005];
int newnode(){
    memset(trie[sz],-1,sizeof(trie[sz]));
    return sz++;
}
void init(){
     sz=0;
     root=newnode();
     memset(fail,-1,sizeof(fail));
}
void insert(char x[],int bh){
     int u=root;
     for(int i=0;x[i];i++){
        int num=x[i];
        if(trie[u][num]==-1){
            trie[u][num]=newnode();
        }
        u=trie[u][num];
     }
     val[u]=bh;
     return ;
}
void getfail(){
     queue<int>q;
     q.push(root);
     while(!q.empty()){
          int u=q.front();q.pop();
          for(int i=0;i<siz;i++){
              int num=i;
              if(trie[u][num]==-1)continue;
              if(u==root){
                 fail[trie[u][num]]=root;
                 q.push(trie[u][num]);continue;
              }
              int temp=fail[u];
              while(temp!=-1){
                  if(trie[temp][num]!=-1){
                      fail[trie[u][num]]=trie[temp][num];
                      break;
                  }
                  temp=fail[temp];
              }
              if(temp==-1)fail[trie[u][num]]=root;
              q.push(trie[u][num]);
          }
     }
}
int query(char x[],int &ans,int bh){
      int u=root;
      memset(vis,false,sizeof(vis));
      //int ans=0;
      for(int i=0;x[i];i++){
          int num=x[i];
          while(u!=0&&trie[u][num]==-1){
              u=fail[u];
          }
          u=trie[u][num];
          if(u==-1)u=0;
          int temp=u;
          while(temp!=0){
               if(val[temp]>0){
                  //ans+=val[temp];
                   vis[val[temp]]=true;
                   //val[temp]=0;
                   //cout<<val[temp]<<"**"<<endl;
               }
               else break;
               //else break;
               temp=fail[temp];
          }
      }
      bool flag=false;
     for(int i=1;i<=m;i++){
         if(vis[i]&&!flag){
            flag=true;
            printf("web %d: %d",bh,i);
         }
         else if(vis[i]){
            printf(" %d",i);
         }
     }
     //printf("\n");
     if(flag){ printf("\n");
            ans++;}
            return 0;
}
int main()
{   //char x[500];
    scanf("%d",&m);
    init();
    for(int i=1;i<=m;i++){
        scanf("%s",x);
        insert(x,i);
    }
    getfail();
    scanf("%d",&n);
    int ans=0;
    for(int i=1;i<=n;i++){
       scanf("%s",x);
       query(x,ans,i);
    }
    printf("total: %d\n",ans);

    return 0;
}

2 指針的MLE版本

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <queue>
#include <cstring>
using namespace std;
const int maxn=150;
char x[10006];
bool vis[520];
int m;
struct node{
   node * chld[maxn];
   node *fail;
   int sum;
   node(){
      memset(chld,NULL,sizeof(chld));
      fail=NULL;
      sum=0;
   }
};
node *root;
void init(){
    root=new node;return ;
}
void insert(char x[],int bh){
    node *temp=root;
    for(int i=0;x[i];i++){
        int num=x[i];
        if(temp->chld[num]==NULL){
            temp->chld[num]=new node;
        }
        temp=temp->chld[num];
    }
    temp->sum=bh;
    return ;
}
void getfail(){
    queue<node *>q;
     q.push(root);
     while(!q.empty()){
           node *u=q.front();q.pop();
           for(int i=0;i<maxn;i++){
               if(u->chld[i]==NULL)continue;
               if(u==root){
                   u->chld[i]->fail=root;q.push(u->chld[i]);continue;
                   //continue;
                   // 指向根節點。
               }
               node *temp=u->fail;
               while(temp!=NULL){
                   if(temp->chld[i]!=NULL){
                       u->chld[i]->fail=temp->chld[i];
                       break;
                   }
                   temp=temp->fail;
               }
               if(temp==NULL)u->chld[i]->fail=root;
               q.push(u->chld[i]);
           }
     }
     return ;
}
int query(int bh,char x[],int &ans){
     node *u=root;
     memset(vis,false,sizeof(vis));
     //cout<<"??????"<<endl;
     //int ans=0;
     for(int i=0;x[i];i++){
         int num=x[i];
         //cout<<i<<endl;
         while(u!=root&&u->chld[num]==NULL){
              u=u->fail;
         }
         u=u->chld[num];
         if(u==NULL)u=root;
         node *temp=u;
         while(temp!=root){
              if(temp->sum>0){
                 //ans+=temp->sum;
                 //temp->sum=0;
                 vis[temp->sum]=true;
                // temp->sum=0;
              }
              else break;
              temp=temp->fail;
         }
     }
     //puts("yes");
     bool flag=false;
     for(int i=1;i<=m;i++){
         if(vis[i]&&!flag){
            flag=true;
            printf("web %d: %d",bh,i);
         }
         else if(vis[i]){
            printf(" %d",i);
         }
     }
     //printf("\n");
     if(flag){ printf("\n");
            ans++;}
            return 0;
     //return ans;
}
int main()
{
    //cout << "Hello world!" << endl;

    scanf("%d",&m);
    init();
    for(int i=1;i<=m;i++){
        scanf("%s",x);
        insert(x,i);
    }
    getfail();
    int n;
    scanf("%d",&n);
    int ttt=0;
    //cout<<'a'<<endl;
    for(int i=1;i<=n;i++){
      scanf("%s",x);
       query(i,x,ttt);
     // cout<<"??"<<endl;
    }
    printf("total: %d\n",ttt);
    return 0;
}

hdu3065
和上一題一樣。不過注意包含的情況,我的思路是每次前進一步時,都要用fail走一遭,當然用trie圖優化更好了。。

#include <iostream>
#include<cstdio>
#include <cstring>
#include <queue>
using namespace std;
// 鏈表的普通版本+ trie圖優化版本
// 數組的普通版本+trie圖優化版本。
/*  寫指針版本的竟然會MLE
    爲什麼用 指針版本的會MLE而其他的寫法不會。
    但是用 靜態指針需要 一大塊空間。
*/
const int siz=130;
const int maxn=5e5;
int trie[maxn][siz];
int vis[1003];
int val[maxn];//每個點的權值。
int fail[maxn];
int root;
int sz;
int m,n;
char x[1002][60];
 char s[2000006];
int newnode(){
    memset(trie[sz],-1,sizeof(trie[sz]));
    return sz++;
}
void init(){
     sz=0;
     root=newnode();
     memset(fail,-1,sizeof(fail));
}
void insert(char x[],int bh){
     int u=root;
     for(int i=0;x[i];i++){
        int num=x[i];
        if(trie[u][num]==-1){
            trie[u][num]=newnode();
        }
        u=trie[u][num];
     }
     val[u]=bh;
     return ;
}
void getfail(){
     queue<int>q;
     q.push(root);
     while(!q.empty()){
          int u=q.front();q.pop();
          //cout<<u<<endl;
          for(int i=0;i<siz;i++){
              int num=i;

              if(trie[u][num]!=-1){
              if(u==root){
                 fail[trie[u][num]]=root;
                 //q.push(trie[u][num]);
              }
              else{
              int temp=fail[u];
              while(temp!=-1){
                  if(trie[temp][num]!=-1){
                      fail[trie[u][num]]=trie[temp][num];
                      break;
                  }
                  temp=fail[temp];
                  //cout<<"**i know"<<temp<<"**"<<num<<endl;
              }
              if(temp==-1)fail[trie[u][num]]=root;
              }
              q.push(trie[u][num]);

              }
          }
     }
}
int query(char xx[]){
      int u=root;
      memset(vis,0,sizeof(vis));
      //int ans=0;
      for(int i=0;xx[i];i++){
          int num=xx[i];
          //cout<<i<<endl;
          while(u!=0&&trie[u][num]==-1){
              u=fail[u];
          }
          u=trie[u][num];
          if(u==-1)u=0;
          int temp=u;
          while(temp!=0){
               //cout<<temp<<"**"<<i<<endl;
               if(val[temp]>0){
                  //ans+=val[temp];
                   vis[val[temp]]++;
                   //cout<<val[temp]<<endl;
                   //val[temp]=0;
                   //cout<<val[temp]<<"**"<<endl;
               }

               //else break;
               temp=fail[temp];
          }
      }
      bool flag=false;
     for(int i=1;i<=m;i++){
         if(vis[i]){
            flag=true;
            printf("%s: %d\n",x[i],vis[i]);
         }
     }
     //printf("\n");

            return 0;
}
int main()
{
    while(~scanf("%d",&m)){
    init();
    memset(val,0,sizeof(val));
    for(int i=1;i<=m;i++){
        scanf("%s",x[i]);
        insert(x[i],i);
    }
    //cout<<trie[0][100]<<endl;
    getfail();
    //for()
     //cout<<"??"<<endl;
      //int x=trie[root][(int)('a')];
      //int y=trie[x][(int)('b')];
      //cout<<fail[y]<<endl;
      //cout<<trie[root][(int)('b')]<<endl;
      //cout<<val[fail[y]]<<endl;
       scanf("%s",s);


       query(s);
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章