POJ 2778 DNA Sequence
題意:給你個病毒的DNA序列,現在要造出一個長度爲的DNA序列,問你有多少種不含病毒DNA序列的方案。
首先可以看到要構造的序列長度很大,達到了2e9(20億),遍歷一遍都會超時,肯定得寫一個時間複雜度在以下的算法。
怎麼解決這一問題呢?先直接說結論吧:
用AC自動機構造出鄰接矩陣,然後跑矩陣快速冪,最後取矩陣第0行元素之和即可。
看了這句話是不是一頭霧水?(我也一樣) 下面來分析一下爲什麼要構造鄰接矩陣。
在離散數學中有這樣一個結論:
說人話 就是:從點到點恰好經過步的方案數,爲鄰接矩陣的次冪得到的矩陣(假設是)中的元素。(具體解法詳見這篇文章)
那麼這一結論對本題有什麼啓示呢?
所謂構造一個序列,其實就是讓一個點從根節點開始走,保證走到的第一個點是序列的第一個元素,第二個點是第二個元素,走的過程實際上就是在Trie圖中進行狀態的轉移。先看看暴力的想法:直接讓一個點從根節點出發,在Trie圖中“隨便亂走”,由於Trie圖在Trie樹的基礎上補全了不存在的出邊節點, 那麼每個點在下一步都有四個點(A,T,C,G)的選擇,走到m步就停止,當然走的過程中不能經過病毒串終點,這樣就是合法的序列。但是之前已經說過,不可能走m步,因爲m太大了。
一次性走m步求不出,但是可以求每個點只走一步能轉移到哪些點,這實際上就是求鄰接矩陣。
如果我們能求出鄰接矩陣,那麼再求出,就得到了走m步能到達的點的所有情況。在矩陣中,第0行,第 j 列元素表示從0點開始走m步走到 j 點可能的方案數,求和即可。
注意ac自動機中的cnt[]成爲病毒串標記,在構造fail指針時記得將標記向下傳遞。
AC代碼:
#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
using namespace std;
typedef long long ll;
const int N=105,M=4,K=10,mod=1e5;
int to_int(char c) // A,T,C,G -> 0,1,2,3
{
if(c=='A')return 0;
else if(c=='T')return 1;
else if(c=='C')return 2;
else return 3;
}
struct matrix
{
ll m[N][N];
matrix() // 構造函數,初始化
{
memset(m,0,sizeof(m));
};
};
struct trie
{
int ch[N][M];
int fail[N];
bool cnt[N];
int tot;
queue<int>q;
void ins(char s[])
{
int u=0;
for(int i=0;s[i];i++)
{
int x=to_int(s[i]);
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}
// 病毒串終點標記(題目應該保證了不會出現兩個相同病毒串)
cnt[u]=1;
}
void build_fail()
{
for(int i=0;i<M;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<M;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
fail[v]=f;
cnt[v]|=cnt[f];// 等價於 if(cnt[f])cnt[v]=1;
// 病毒串終點標記向下傳遞(不要寫反了!是f傳遞到v!)
q.push(v);
}
else v=f;
}
}
}
matrix build_matrix() // 構建鄰接矩陣
{
matrix ans=matrix();
for(int i=0;i<=tot;i++) // tot+1個點
{
if(cnt[i])continue; // u不能是病毒串終點
for(int j=0;j<M;j++) // 每個點有M條出邊
{
int v=ch[i][j]; // 走到的下一節點爲v
if(!cnt[v]) // v不能是病毒串終點
ans.m[i][v]++;
}
}
return ans;
}
}ac;
matrix mul(matrix s1,matrix s2) // 兩矩陣相乘
{
matrix ans=matrix();
int sz=ac.tot+1;
for(int i=0;i<sz;i++)
{
for(int j=0;j<sz;j++)
{
for(int k=0;k<sz;k++)
{
ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];
// ans.m[i][j]=(ans.m[i][j]+s1.m[i][k]*s2.m[k][j]%mod)%mod;
// 取模這樣寫容易超時,在保證不會爆long long的情況下
// 應該先加起來存long long裏,最後再取模
}
ans.m[i][j]%=mod;
}
}
return ans;
}
matrix matrix_pow(matrix a,int b) // 矩陣a的b次冪
{
matrix ans=matrix();
int sz=ac.tot+1;
for(int i=0;i<sz;i++)
ans.m[i][i]=1; // 單位矩陣
while(b)
{
if(b&1)ans=mul(ans,a);
b/=2;
a=mul(a,a);
}
return ans;
}
int n,m;
char t[K];
int main()
{
ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>t;
ac.ins(t);
}
ac.build_fail();
matrix a=ac.build_matrix(); // 得到鄰接矩陣a
matrix ans=matrix_pow(a,m); // 得到矩陣a^m
ll sum=0;
for(int i=0;i<=ac.tot;i++) // 累加 0 —> 0,1,...tot 的所有方案數
sum=(sum+ans.m[0][i])%mod;
printf("%lld\n",sum);
return 0;
}
/*
2 32
A
T
ans:67296
10 100
AGAGAGT
CGTATTG
AAAATTTCGC
GCGTA
TCGA
AATTGGA
TAGATAGC
AGCGTATT
TTCGA
TACGTATTG
ans:35771
*/
HDU 2243 考研路茫茫——單詞情結
和上題差不多,這個是要求<=m的所有方案,構造一個矩陣[{E,E},{0,A}]進行快速冪即可得到A0+A1+…+Am。
然後對264取模的意思就是定義成unsigned long long,計算過程中會自動對264取模 (我先還以爲是大數取模呢)
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ll;
const int N=105,M=26,K=5;
struct matrix
{
ll m[N][N];
matrix() // 構造函數,初始化
{
memset(m,0,sizeof(m));
};
};
struct trie
{
int ch[N][M];
bool cnt[N];
int fail[N];
int tot;
queue<int>q;
void init()
{
tot=0;
memset(cnt,0,sizeof(cnt));
memset(fail,0,sizeof(fail));
memset(ch,0,sizeof(ch));
}
void ins(char s[])
{
int u=0;
for(int i=0;s[i];i++)
{
int x=s[i]-'a'; // a~z -> 0~25
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}
cnt[u]=1;
}
void build_fail()
{
for(int i=0;i<M;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<M;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
fail[v]=f;
cnt[v]|=cnt[f];
q.push(v);
}
else v=f;
}
}
}
matrix build_matrix() // 得到鄰接矩陣
{
matrix ans=matrix();
for(int i=0;i<=tot;i++)
{
if(cnt[i])continue;
for(int j=0;j<M;j++)
{
int v=ch[i][j];
if(!cnt[v])
ans.m[i][v]++;
}
}
return ans;
}
}ac;
matrix mul(matrix s1,matrix s2,int sz)
{
matrix ans=matrix();
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
for(int k=0;k<sz;k++)
ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];// 不用取模!
return ans;
}
matrix matrix_pow(matrix a,int b,int sz)
{
matrix ans=matrix();
for(int i=0;i<sz;i++)
ans.m[i][i]=1; // 單位矩陣
while(b)
{
if(b&1)ans=mul(ans,a,sz);
a=mul(a,a,sz);
b/=2;
}
return ans;
}
int n,m;
char t[K];
int main()
{
ios::sync_with_stdio(false);
while(cin>>n>>m)
{
ac.init();
for(int i=1;i<=n;i++)
{
cin>>t;
ac.ins(t);
}
ac.build_fail();
matrix tmp=matrix();
tmp.m[0][0]=1,tmp.m[0][1]=1,tmp.m[1][1]=26;
matrix s1=matrix_pow(tmp,m+1,2);
ll sum1=s1.m[0][1]-1; // 總方案數sum1
matrix a=ac.build_matrix(); // 鄰接矩陣
int sz=ac.tot+1;
matrix b=matrix();
for(int i=0;i<sz;i++)
{
b.m[i][i]=1;
b.m[i][i+sz]=1;
}
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
b.m[i+sz][j+sz]=a.m[i][j];
matrix s2=matrix_pow(b,m+1,2*sz);
ll sum2=0;
for(int i=sz;i<2*sz;i++)
sum2+=s2.m[0][i];
sum2--; //不含模式串的方案數sum2
//printf("sum1=%I64u sum2=%I64u\n",sum1,sum2); // debug
printf("%I64u\n",sum1-sum2);
}
return 0;
}
/*
2 3
aa
ab
sum1=18278 sum2=18174
ans:104
2 13
aa
ab
sum1=2580398988131886038 sum2=2493353857086648626
ans:87045131045237412
2 2000000000
aa
ab
sum1=8116567392432202710 sum2=14915077526685486680
ans:11648233939456267646
*/
POJ 1625 Censored!
這題需要用到高精度,然後因爲冪次比較小可以不用快速冪,直接DP就行,我想着java有大數,那就拿java寫一下大數的矩陣快速冪吧。
注意編碼格式,因爲java沒有unsigned char。將輸入寫成Scanner cin = new Scanner(new BufferedInputStream(System.in), "ISO-8859-1");
即可。
Java代碼:
//package Main; //package信息一定要去掉,否則RE
import java.io.BufferedInputStream;
import java.math.BigInteger;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Scanner;
class Matrix {
BigInteger m[][];
Matrix(int sz, int type) { // type控制零矩陣/單位矩陣
m = new BigInteger[sz][sz];
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
m[i][j] = BigInteger.ZERO;
if(type == 1) {
for(int i = 0; i < sz; i++)
m[i][i] = BigInteger.ONE;
}
}
}
class Trie {
static final int N = 105, M = 55, K = 256;
int ch[][] = new int[N][M];
int fail[] = new int[N];
boolean cnt[] = new boolean[N];
int tot = 0;
int len;
Queue<Integer> q = new LinkedList<Integer>();
int mp[] = new int[K];
void ins(String s) {
int u = 0;
for(int i = 0; i < s.length(); i++)
{
int x=mp[s.charAt(i)]; // ASCII碼 -> 0~len-1(len是字母表長度)
if(ch[u][x] == 0) ch[u][x] = ++tot;
u = ch[u][x];
}
cnt[u] = true;
}
void build_fail() {
for(int i = 0; i < len; i++) {
if(ch[0][i] != 0)
q.offer(ch[0][i]);
}
while(!q.isEmpty()) { // 隊列非空
int u = q.poll(); // 取出並刪除隊頭的元素
for(int i = 0; i < len; i++) {
int v = ch[u][i];
int f = ch[fail[u]][i];
if(v != 0) {
if(cnt[f] == true) cnt[v] = true;
fail[v] = f;
q.offer(v);
}
else ch[u][i] = f;
}
}
}
Matrix build_matrix() {
int sz = tot + 1;
Matrix ans = new Matrix(sz, 0);
for(int i = 0; i < sz; i++) {
if(cnt[i]) continue;
for(int j = 0;j < len; j++) {
int v = ch[i][j];
if(!cnt[v])
ans.m[i][v] = ans.m[i][v].add(BigInteger.ONE);
}
}
return ans;
}
Matrix mul(Matrix s1, Matrix s2) {
int sz = tot + 1;
Matrix ans = new Matrix(sz, 0);
for(int i = 0; i < sz; i++)
for(int j = 0; j < sz; j++)
for(int k = 0; k < sz; k++)
ans.m[i][j] = ans.m[i][j].add(s1.m[i][k].multiply(s2.m[k][j]));
return ans;
}
Matrix matrix_pow(Matrix a, int b) {
int sz = tot + 1;
Matrix ans = new Matrix(sz, 1);
while(b != 0) {
if(b%2 == 1) ans = mul(ans,a);
a = mul(a,a);
b /= 2;
}
return ans;
}
}
public class Main {
public static void main(String[] args) {
Scanner cin = new Scanner(new BufferedInputStream(System.in), "ISO-8859-1");
//Scanner cin = new Scanner(System.in); 會RE
Trie ac = new Trie();
ac.len = cin.nextInt();
int m = cin.nextInt();
int n = cin.nextInt();
String s = cin.next();
for(int i = 0; i < ac.len; i++) {
ac.mp[s.charAt(i)] = i;
}
for(int i = 1; i <= n; i++) {
String t = cin.next(); // 病毒串
ac.ins(t);
}
ac.build_fail();
Matrix a = ac.build_matrix(); // 鄰接矩陣
Matrix ans = ac.matrix_pow(a,m);
BigInteger sum = BigInteger.ZERO;
int sz = ac.tot + 1;
for(int i = 0; i < sz; i++) {
sum = sum.add(ans.m[0][i]);
}
System.out.println(sum);
}
}
C++代碼(沒寫大數,不能AC,只是作爲java代碼的“翻譯”):
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <cmath>
using namespace std;
typedef unsigned long long ll;
typedef unsigned char uc;
const int N=105,M=55,K=256;
int mp[K];
int len,n,m;
struct matrix
{
ll m[N][N];
matrix()
{
memset(m,0,sizeof(m));
}
};
struct trie
{
int ch[N][M];
int fail[N];
bool cnt[N];
int tot;
queue<int>q;
void init()
{
memset(cnt,0,sizeof(cnt));
memset(fail,0,sizeof(fail));
memset(ch,0,sizeof(ch));
tot=0;
}
void ins(uc s[])
{
int u=0;
for(int i=0;s[i];i++)
{
int x=mp[s[i]]; // ASCII碼 -> 0~len-1(len是字母表長度)
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}
cnt[u]=1;
}
void build_fail()
{
for(int i=0;i<len;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<len;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
cnt[v]|=cnt[f];
fail[v]=f;
q.push(v);
}
else v=f;
}
}
}
matrix build_matrix()
{
int sz=tot+1;
matrix ans=matrix();
for(int i=0;i<sz;i++)
{
if(cnt[i])continue;
for(int j=0;j<len;j++)
{
int v=ch[i][j];
if(!cnt[v])
ans.m[i][v]++;
}
}
return ans;
}
}ac;
matrix mul(matrix s1,matrix s2)
{
int sz=ac.tot+1;
matrix ans=matrix();
for(int i=0;i<sz;i++)
for(int j=0;j<sz;j++)
for(int k=0;k<sz;k++)
ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];
return ans;
}
matrix matrix_pow(matrix a,int b)
{
int sz=ac.tot+1;
matrix ans=matrix();
for(int i=0;i<sz;i++)
ans.m[i][i]=1;
while(b)
{
if(b&1)ans=mul(ans,a);
a=mul(a,a);
b/=2;
}
return ans;
}
int main()
{
ios::sync_with_stdio(false);
uc c,t[12];
cin>>len>>m>>n;
memset(mp,0,sizeof(mp));
ac.init();
for(int i=0;i<len;i++)
{
cin>>c;
mp[c]=i;
}
for(int i=1;i<=n;i++)
{
cin>>t;
ac.ins(t);
}
ac.build_fail();
int sz=ac.tot+1;
matrix a=ac.build_matrix();
matrix ans=matrix_pow(a,m);
ll sum=0;
for(int i=0;i<sz;i++)
sum+=ans.m[0][i];
printf("%llu",sum);
return 0;
}