poj 2778

这道题类似于poj 1625,只不过字符集变得更小,只有’A’,’G’,’C’,’T’四个字符。相应的,序列的长度大幅度的增加至2*10^9。假设序列长度为m,总的trie图节点数为sz,原先的dp方法的复杂度为O(m*sz^2)。若仍采用dp的方法,那么总的操作数达到10^11,肯定不可取。因此需要其他的方法。矩阵可以解决这个问题。求出trie图的邻接矩阵,两点之间的权值为这两点间的边数,矩阵中的a[i][j]表示由节点i,到节点j走一步的方法数。此矩阵的n次幂后a’[i][j]即为由节点i走n步到节点j的方法数。因此题目的答案a[root][i]的和,(0<=i<sz)。同样,要记得去除所有危险节点的行和列,即权值标为0,表示没有这条边。
矩阵的计算过程采用快速幂。复杂度为O(log(n))
更多的解释可以参见poj 1625
下面贴代码:

#include <cstdio>
#include <queue>
#include <cstring>
#define sigma_size 4
#define maxnode 105
#define modnum 100000
using namespace std;

int m, n;

int ch[maxnode][sigma_size];
int val[maxnode];
int sz;
int f[maxnode];
int last[maxnode];

void initial()
{
    sz = 1;
    memset(ch[0], 0, sizeof(ch[0]));
}
int charget(char a)
{
    switch(a)
    {
    case 'A':
        return 0;
    case 'G':
        return 1;
    case 'C':
        return 2;
    case 'T':
        return 3;
    }
    return -1;
}
void insert(char *S)
{
    int l = strlen(S);
    int u = 0;
    for(int i = 0; i < l; i++)
    {
        int c = charget(S[i]);
        if(!ch[u][c])
        {
            memset(ch[sz], 0, sizeof(ch[sz]));
            val[sz] = 0;
            ch[u][c] = sz++;
        }
        u = ch[u][c];
    }
    val[u] = 1;
}
void getfail()
{
    queue<int>q;
    f[0] = 0;
    for(int c = 0; c < sigma_size; c++)
    {
        int u = ch[0][c];
        if(u)
        {
            f[u] = 0;
            q.push(u);
            last[u] = 0;
        }
    }
    while(!q.empty())
    {
        int r = q.front();
        q.pop();
        for(int c = 0; c < sigma_size; c++)
        {
            int u = ch[r][c];
            if(!u)
            {
                ch[r][c] = ch[f[r]][c];
                continue;
            }
            q.push(u);
            int v = f[r];
            while(v && !ch[v][c]) v = f[v];
            f[u] = ch[v][c];
            last[u] = val[f[u]] ? f[u] : last[f[u]];
        }
    }
}
long long matrix[maxnode][maxnode];
long long res1[maxnode][maxnode];
long long res2[maxnode][maxnode];
void buildmatrix()
{
    memset(matrix, 0, sizeof(matrix));
    for(int i = 0; i < sz; i++)
    {
        if(val[i] || last[i]) //去除非法节点出边
            continue;
        for(int j = 0 ; j < sigma_size; j++)
        {
            if(val[ch[i][j]] || last[ch[i][j]])
                continue;//去除非法节点入边
            matrix[i][ch[i][j]]++;
        }
    }
}
void mul(long long a[][maxnode], long long b[][maxnode], long long c[][maxnode])
{
    memset(c, 0, sizeof(matrix));

    for(int i = 0; i < sz; i++)
        for(int j = 0; j < sz; j++)
        {
            for(int k = 0; k < sz; k++)
                c[i][j] += (a[i][k]*b[k][j]);
            c[i][j] %= modnum;
        }

}
void swap(long long a[][maxnode],long long b[][maxnode])
{
    long long tmp;
    for(int i = 0; i <sz;i++)
        for(int j = 0; j<sz;j++)
        {
            tmp = a[i][j];
            a[i][j] = b[i][j];
            b[i][j] = tmp;
        }
}
void multiple(int n)
{
    if(n == 1)
    {
        for(int i = 0; i < sz; i++)
            for(int j = 0; j < sz; j++)
                res1[i][j] = matrix[i][j];
        return;
    }
    multiple(n / 2);
    mul(res1, res1, res2);
    if(n % 2)
    {
        mul(res2, matrix, res1);
    }
    else
        swap(res2, res1);
}

char ban[12];
int main()
{
    initial();
    scanf("%d%d", &m, &n);
    for(int i = 0; i < m; i++)
    {
        scanf("%s", ban);
        insert(ban);
    }
    getfail();
    buildmatrix();
    multiple(n);
    long long ans = 0;

    for(int i = 0; i < sz; i++)
        ans += res1[0][i];
    printf("%lld\n",ans%modnum);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章