poj 2778

這道題類似於poj 1625,只不過字符集變得更小,只有’A’,’G’,’C’,’T’四個字符。相應的,序列的長度大幅度的增加至2*10^9。假設序列長度爲m,總的trie圖節點數爲sz,原先的dp方法的複雜度爲O(m*sz^2)。若仍採用dp的方法,那麼總的操作數達到10^11,肯定不可取。因此需要其他的方法。矩陣可以解決這個問題。求出trie圖的鄰接矩陣,兩點之間的權值爲這兩點間的邊數,矩陣中的a[i][j]表示由節點i,到節點j走一步的方法數。此矩陣的n次冪後a’[i][j]即爲由節點i走n步到節點j的方法數。因此題目的答案a[root][i]的和,(0<=i<sz)。同樣,要記得去除所有危險節點的行和列,即權值標爲0,表示沒有這條邊。
矩陣的計算過程採用快速冪。複雜度爲O(log(n))
更多的解釋可以參見poj 1625
下面貼代碼:

#include <cstdio>
#include <queue>
#include <cstring>
#define sigma_size 4
#define maxnode 105
#define modnum 100000
using namespace std;

int m, n;

int ch[maxnode][sigma_size];
int val[maxnode];
int sz;
int f[maxnode];
int last[maxnode];

void initial()
{
    sz = 1;
    memset(ch[0], 0, sizeof(ch[0]));
}
int charget(char a)
{
    switch(a)
    {
    case 'A':
        return 0;
    case 'G':
        return 1;
    case 'C':
        return 2;
    case 'T':
        return 3;
    }
    return -1;
}
void insert(char *S)
{
    int l = strlen(S);
    int u = 0;
    for(int i = 0; i < l; i++)
    {
        int c = charget(S[i]);
        if(!ch[u][c])
        {
            memset(ch[sz], 0, sizeof(ch[sz]));
            val[sz] = 0;
            ch[u][c] = sz++;
        }
        u = ch[u][c];
    }
    val[u] = 1;
}
void getfail()
{
    queue<int>q;
    f[0] = 0;
    for(int c = 0; c < sigma_size; c++)
    {
        int u = ch[0][c];
        if(u)
        {
            f[u] = 0;
            q.push(u);
            last[u] = 0;
        }
    }
    while(!q.empty())
    {
        int r = q.front();
        q.pop();
        for(int c = 0; c < sigma_size; c++)
        {
            int u = ch[r][c];
            if(!u)
            {
                ch[r][c] = ch[f[r]][c];
                continue;
            }
            q.push(u);
            int v = f[r];
            while(v && !ch[v][c]) v = f[v];
            f[u] = ch[v][c];
            last[u] = val[f[u]] ? f[u] : last[f[u]];
        }
    }
}
long long matrix[maxnode][maxnode];
long long res1[maxnode][maxnode];
long long res2[maxnode][maxnode];
void buildmatrix()
{
    memset(matrix, 0, sizeof(matrix));
    for(int i = 0; i < sz; i++)
    {
        if(val[i] || last[i]) //去除非法節點出邊
            continue;
        for(int j = 0 ; j < sigma_size; j++)
        {
            if(val[ch[i][j]] || last[ch[i][j]])
                continue;//去除非法節點入邊
            matrix[i][ch[i][j]]++;
        }
    }
}
void mul(long long a[][maxnode], long long b[][maxnode], long long c[][maxnode])
{
    memset(c, 0, sizeof(matrix));

    for(int i = 0; i < sz; i++)
        for(int j = 0; j < sz; j++)
        {
            for(int k = 0; k < sz; k++)
                c[i][j] += (a[i][k]*b[k][j]);
            c[i][j] %= modnum;
        }

}
void swap(long long a[][maxnode],long long b[][maxnode])
{
    long long tmp;
    for(int i = 0; i <sz;i++)
        for(int j = 0; j<sz;j++)
        {
            tmp = a[i][j];
            a[i][j] = b[i][j];
            b[i][j] = tmp;
        }
}
void multiple(int n)
{
    if(n == 1)
    {
        for(int i = 0; i < sz; i++)
            for(int j = 0; j < sz; j++)
                res1[i][j] = matrix[i][j];
        return;
    }
    multiple(n / 2);
    mul(res1, res1, res2);
    if(n % 2)
    {
        mul(res2, matrix, res1);
    }
    else
        swap(res2, res1);
}

char ban[12];
int main()
{
    initial();
    scanf("%d%d", &m, &n);
    for(int i = 0; i < m; i++)
    {
        scanf("%s", ban);
        insert(ban);
    }
    getfail();
    buildmatrix();
    multiple(n);
    long long ans = 0;

    for(int i = 0; i < sz; i++)
        ans += res1[0][i];
    printf("%lld\n",ans%modnum);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章