kuangbin專題十七 HDU2243 (經典好題) AC自動機+矩陣快速冪

題意:

題解:
這道題跟POJ2278相似,POJ2278求的是不包含,這道題求的是包含,那麼我們就算出總和減去不包含的數量就可以得出包含的數量了。思路是從博客上學到的,但是laoda的代碼,可能是我太垃圾了,看着不太理解的了,所以還是自己折騰弄了一下午,終於弄出來了,折騰一下午的原因是flag和fail看錯。 。。MD。。
以下是給我思路和公式的博客:
http://www.xuebuyuan.com/1154746.html
http://blog.csdn.net/runner__/article/details/51394417#
http://blog.csdn.net/alxpcun/article/details/51217002
http://blog.csdn.net/baidu_23081367/article/details/52347256
這裏解釋一下unsigned long long爲什麼超過2^64會自動mod2^64,因爲 unsigned long long的範圍爲(0,2^64-1)而且是沒負數的,所以當你到達2^64的時候就會自動轉化爲0。具體看一下下面這篇博客:
http://blog.csdn.net/u014800748/article/details/45439857
另外那個矩陣的公式大家去看上面那4個博客就應該能弄懂了,記得在你得出矩陣之後一定要在旁邊加1列 1,就相當於弄了個E進去矩陣中了。
然後感覺詫異的就是爲什麼不直接算出包含的字符串呢?我自己的理解就是AC自動機自身的特點決定了要求出不包含的數量再讓總數減去。因爲你AC自動機有一個失配指針,可以使你在某個字母不匹配的時候轉移去另外一個具有公共後綴的匹配串中,這樣得出來的不匹配字符串速度快很多,而你AC自動機匹配成功之後是不會轉移的,所以很難求出長度爲m的包含匹配串的字符串,因爲你長度爲m的字符串有26^m種變化啊,求到你吐血啊。不知道本菜的解釋合不合理,希望有dalao指出我的不足和不對,免得誤導了大家,謝謝了。

#include<stdio.h>
#include<string.h>
#include<queue>
#include<algorithm>
using namespace std;
#define ull unsigned long long int 
struct Trie
{
    bool flag;
    int id;
    Trie *fail;
    Trie *next[26];
    Trie(){
        id=0;
        flag=false;
        fail=0;
        for(int i=0;i<26;i++)
        next[i]=0;
    }
}*root;
struct node
{
    ull m[50][50];
    node(){
        memset(m,0,sizeof(m));
    }
};
int tot;
Trie *dfn[50];
void insert(char *s)
{
    Trie *r=root;
    for(int i=0;s[i];i++)
    {
        int x=s[i]-'a';
        if(r->next[x]==0){
            r->next[x]=new(Trie);
            r=r->next[x];   
            r->id=tot;
            dfn[tot++]=r;
        }
        else
        r=r->next[x];
    }
    r->flag=true;
}
void getfail()//這裏要改點東西 
{
    queue<Trie *>q;
    q.push(root);
    Trie *p;
    Trie *temp;
    while(!q.empty())
    {
        temp=q.front();
        q.pop();
        for(int i=0;i<26;i++)
        {
            if(temp->next[i])
            {
                if(temp==root)
                {
                    temp->next[i]->fail=root;
                }
                else
                {
                    p=temp->fail;
                    while(p)
                    {
                        if(p->next[i])
                        {
                            temp->next[i]->fail=p->next[i];
                            break;
                        }
                        p=p->fail;
                    }
                    if(p==0)
                    temp->next[i]->fail=root;
                }
                if(temp->next[i]->fail->flag)//兒子是危險串時,自己也應該爲危險串,就相當於C是危險字符串,AC中A不是危險的,但是匹配了C成爲了危險的。 
                temp->next[i]->flag=true;//或者可以這樣理解:表示這個前綴是詞根,如acg,ac.   
                q.push(temp->next[i]);
            }
            else//這一步也是關鍵,相當於把空的節點補充完整,使每個子串都有兒子
            {
                if(temp->fail) temp->next[i]=temp->fail->next[i];
                else temp->next[i]=root;
            } 
        }
    }
}

node getMatrix()
{
    node A;
    for(int i=0;i<tot;i++)
        for(int j=0;j<26;j++)
        if(!dfn[i]->flag&&!dfn[i]->next[j]->flag)
        A.m[dfn[i]->id][dfn[i]->next[j]->id]++;
    for(int i=0;i<=tot;i++)//在原矩陣中加一列1,就相當於弄了E進去矩陣中。 
    A.m[i][tot]=1;
    return A;       
}
node cla(node A,node B)
{
    node C;
    for(int i=0;i<=tot;i++)
        for(int j=0;j<=tot;j++)
            for(int k=0;k<=tot;k++)
            if(A.m[i][k]&&B.m[k][j])
            {
                C.m[i][j]+=A.m[i][k]*B.m[k][j];
            } 
    return C;       
}
node POW(node A,int k,int n)
{
    node C;
    for(int i=0;i<=n;i++) C.m[i][i]=1;
    while(k)
    {
        if(k&1) C=cla(C,A);
        A=cla(A,A);
        k>>=1;
    } 
    return C; 
}
int main()
{
    int n,m;
    while(~scanf("%d%d",&n,&m))
    {
        root=new(Trie);
        tot=0;
        dfn[tot++]=root;
        char c[10];
        for(int i=0;i<n;i++)
        {
            scanf("%s",c);
            insert(c);
        }
        getfail();
        node A=getMatrix();
//      for(int i=0;i<=tot;i++)
//      {
//          for(int j=0;j<=tot;j++)
//          printf("%llu ",A.m[i][j]);
//          printf("\n");
//      }
        A=POW(A,m,tot+1);//這裏計算的是A^0+A^1 + A^2 + A^3 + ... + A^m. 
        ull ans1=0;
        for(int i=0;i<=tot;i++)
        ans1+=A.m[0][i];
        node B;
        B.m[0][0]=26,B.m[0][1]=B.m[1][1]=1,B.m[1][0]=0;
        B=POW(B,m,2);//這裏計算的是26^0+26^1 + 26^2 + 26^3 + ... +26^m. 
        ull ans2=0;
        for(int i=0;i<2;i++) 
        ans2+=B.m[0][i];
//      printf("%llu\n",ans2);
        printf("%llu\n",ans2-ans1);
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章