這道題類似於poj 1625,只不過字符集變得更小,只有’A’,’G’,’C’,’T’四個字符。相應的,序列的長度大幅度的增加至2*10^9。假設序列長度爲m,總的trie圖節點數爲sz,原先的dp方法的複雜度爲O(m*sz^2)。若仍採用dp的方法,那麼總的操作數達到10^11,肯定不可取。因此需要其他的方法。矩陣可以解決這個問題。求出trie圖的鄰接矩陣,兩點之間的權值爲這兩點間的邊數,矩陣中的a[i][j]表示由節點i,到節點j走一步的方法數。此矩陣的n次冪後a’[i][j]即爲由節點i走n步到節點j的方法數。因此題目的答案a[root][i]的和,(0<=i<sz)。
同樣,要記得去除所有危險節點的行和列,即權值標爲0,表示沒有這條邊。
矩陣的計算過程採用快速冪。複雜度爲O(log(n))
更多的解釋可以參見poj 1625
下面貼代碼:
#include <cstdio>
#include <queue>
#include <cstring>
#define sigma_size 4
#define maxnode 105
#define modnum 100000
using namespace std;
int m, n;
int ch[maxnode][sigma_size];
int val[maxnode];
int sz;
int f[maxnode];
int last[maxnode];
void initial()
{
sz = 1;
memset(ch[0], 0, sizeof(ch[0]));
}
int charget(char a)
{
switch(a)
{
case 'A':
return 0;
case 'G':
return 1;
case 'C':
return 2;
case 'T':
return 3;
}
return -1;
}
void insert(char *S)
{
int l = strlen(S);
int u = 0;
for(int i = 0; i < l; i++)
{
int c = charget(S[i]);
if(!ch[u][c])
{
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz] = 0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
val[u] = 1;
}
void getfail()
{
queue<int>q;
f[0] = 0;
for(int c = 0; c < sigma_size; c++)
{
int u = ch[0][c];
if(u)
{
f[u] = 0;
q.push(u);
last[u] = 0;
}
}
while(!q.empty())
{
int r = q.front();
q.pop();
for(int c = 0; c < sigma_size; c++)
{
int u = ch[r][c];
if(!u)
{
ch[r][c] = ch[f[r]][c];
continue;
}
q.push(u);
int v = f[r];
while(v && !ch[v][c]) v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]] ? f[u] : last[f[u]];
}
}
}
long long matrix[maxnode][maxnode];
long long res1[maxnode][maxnode];
long long res2[maxnode][maxnode];
void buildmatrix()
{
memset(matrix, 0, sizeof(matrix));
for(int i = 0; i < sz; i++)
{
if(val[i] || last[i]) //去除非法節點出邊
continue;
for(int j = 0 ; j < sigma_size; j++)
{
if(val[ch[i][j]] || last[ch[i][j]])
continue;//去除非法節點入邊
matrix[i][ch[i][j]]++;
}
}
}
void mul(long long a[][maxnode], long long b[][maxnode], long long c[][maxnode])
{
memset(c, 0, sizeof(matrix));
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
{
for(int k = 0; k < sz; k++)
c[i][j] += (a[i][k]*b[k][j]);
c[i][j] %= modnum;
}
}
void swap(long long a[][maxnode],long long b[][maxnode])
{
long long tmp;
for(int i = 0; i <sz;i++)
for(int j = 0; j<sz;j++)
{
tmp = a[i][j];
a[i][j] = b[i][j];
b[i][j] = tmp;
}
}
void multiple(int n)
{
if(n == 1)
{
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
res1[i][j] = matrix[i][j];
return;
}
multiple(n / 2);
mul(res1, res1, res2);
if(n % 2)
{
mul(res2, matrix, res1);
}
else
swap(res2, res1);
}
char ban[12];
int main()
{
initial();
scanf("%d%d", &m, &n);
for(int i = 0; i < m; i++)
{
scanf("%s", ban);
insert(ban);
}
getfail();
buildmatrix();
multiple(n);
long long ans = 0;
for(int i = 0; i < sz; i++)
ans += res1[0][i];
printf("%lld\n",ans%modnum);
return 0;
}