AC自動機 進階練習 (結合算法:矩陣快速冪/DP/高精度)

POJ 2778 DNA Sequence

題意:給你nn個病毒的DNA序列,現在要造出一個長度爲mm的DNA序列,問你有多少種不含病毒DNA序列的方案。

首先可以看到要構造的序列長度很大,達到了2e9(20億),遍歷一遍都會超時,肯定得寫一個時間複雜度在O(logn)O(logn)以下的算法。

怎麼解決這一問題呢?先直接說結論吧:
用AC自動機構造出鄰接矩陣,然後跑矩陣快速冪,最後取矩陣第0行元素之和即可。

看了這句話是不是一頭霧水?(我也一樣) 下面來分析一下爲什麼要構造鄰接矩陣。

在離散數學中有這樣一個結論:
在這裏插入圖片描述
說人話 就是:從uu點到vv點恰好經過kk步的方案數,爲鄰接矩陣的kk次冪得到的矩陣(假設是ansans)中的元素ans[u][v]ans[u][v]。(具體解法詳見這篇文章

那麼這一結論對本題有什麼啓示呢?
所謂構造一個序列,其實就是讓一個點從根節點開始走,保證走到的第一個點是序列的第一個元素,第二個點是第二個元素,走的過程實際上就是在Trie圖中進行狀態的轉移。先看看暴力的想法:直接讓一個點從根節點出發,在Trie圖中“隨便亂走”,由於Trie圖在Trie樹的基礎上補全了不存在的出邊節點, 那麼每個點在下一步都有四個點(A,T,C,G)的選擇,走到m步就停止,當然走的過程中不能經過病毒串終點,這樣就是合法的序列。但是之前已經說過,不可能走m步,因爲m太大了。

一次性走m步求不出,但是可以求每個點只走一步能轉移到哪些點,這實際上就是求鄰接矩陣。

如果我們能求出鄰接矩陣AA,那麼再求出AmA^m,就得到了走m步能到達的點的所有情況。在AmA^m矩陣中,第0行,第 j 列元素表示從0點開始走m步走到 j 點可能的方案數,求和即可。

注意ac自動機中的cnt[]成爲病毒串標記,在構造fail指針時記得將標記向下傳遞。

AC代碼:

#include <iostream>
#include <cstdio>
#include <queue>
#include <cstring>
using namespace std;
typedef long long ll;
const int N=105,M=4,K=10,mod=1e5;
int to_int(char c) // A,T,C,G -> 0,1,2,3
{
    if(c=='A')return 0;
    else if(c=='T')return 1;
    else if(c=='C')return 2;
    else return 3;
}
struct matrix
{
    ll m[N][N];
    matrix() // 構造函數,初始化
    {
        memset(m,0,sizeof(m));
    };
};
struct trie
{
    int ch[N][M];
    int fail[N];
    bool cnt[N];
    int tot;
    queue<int>q;
    void ins(char s[])
    {
        int u=0;
        for(int i=0;s[i];i++)
        {
            int x=to_int(s[i]);
            if(!ch[u][x])ch[u][x]=++tot;
            u=ch[u][x];
        }
        // 病毒串終點標記(題目應該保證了不會出現兩個相同病毒串)
        cnt[u]=1;
    }
    void build_fail()
    {
        for(int i=0;i<M;i++)
        {
            if(ch[0][i])
                q.push(ch[0][i]);
        }
        while(!q.empty())
        {
            int u=q.front();q.pop();
            for(int i=0;i<M;i++)
            {
                int &v=ch[u][i];
                int f=ch[fail[u]][i];
                if(v)
                {
                    fail[v]=f;
                    cnt[v]|=cnt[f];// 等價於 if(cnt[f])cnt[v]=1;
                    // 病毒串終點標記向下傳遞(不要寫反了!是f傳遞到v!)
                    q.push(v);
                }
                else v=f;
            }
        }
    }
    matrix build_matrix() // 構建鄰接矩陣
    {
        matrix ans=matrix();
        for(int i=0;i<=tot;i++) // tot+1個點
        {
            if(cnt[i])continue; // u不能是病毒串終點
            for(int j=0;j<M;j++) // 每個點有M條出邊
            {
                int v=ch[i][j]; // 走到的下一節點爲v
                if(!cnt[v]) // v不能是病毒串終點
                    ans.m[i][v]++;
            }
        }
        return ans;
    }
}ac;
matrix mul(matrix s1,matrix s2) // 兩矩陣相乘
{
    matrix ans=matrix();
    int sz=ac.tot+1;
    for(int i=0;i<sz;i++)
    {
        for(int j=0;j<sz;j++)
        {
            for(int k=0;k<sz;k++)
            {
                ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];
                // ans.m[i][j]=(ans.m[i][j]+s1.m[i][k]*s2.m[k][j]%mod)%mod; 
                // 取模這樣寫容易超時,在保證不會爆long long的情況下
                // 應該先加起來存long long裏,最後再取模
            }
            ans.m[i][j]%=mod;
        }
    }
    return ans;
}
matrix matrix_pow(matrix a,int b) // 矩陣a的b次冪
{
    matrix ans=matrix();
    int sz=ac.tot+1;
    for(int i=0;i<sz;i++)
        ans.m[i][i]=1; // 單位矩陣
    while(b)
    {
        if(b&1)ans=mul(ans,a);
        b/=2;
        a=mul(a,a);
    }
    return ans;
}
int n,m;
char t[K];
int main()
{
    ios::sync_with_stdio(false);
    cin>>n>>m;
    for(int i=1;i<=n;i++)
    {
        cin>>t;
        ac.ins(t);
    }
    ac.build_fail();
    matrix a=ac.build_matrix(); // 得到鄰接矩陣a
    matrix ans=matrix_pow(a,m); // 得到矩陣a^m
    ll sum=0;
    for(int i=0;i<=ac.tot;i++) // 累加 0 —> 0,1,...tot 的所有方案數
        sum=(sum+ans.m[0][i])%mod;
    printf("%lld\n",sum);
    return 0;
}
/*
2 32
A
T
ans:67296

10 100
AGAGAGT
CGTATTG
AAAATTTCGC
GCGTA
TCGA
AATTGGA
TAGATAGC
AGCGTATT
TTCGA
TACGTATTG
ans:35771
*/

HDU 2243 考研路茫茫——單詞情結

和上題差不多,這個是要求<=m的所有方案,構造一個矩陣[{E,E},{0,A}]進行快速冪即可得到A0+A1+…+Am

然後對264取模的意思就是定義成unsigned long long,計算過程中會自動對264取模 (我先還以爲是大數取模呢)

#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ll;
const int N=105,M=26,K=5;
struct matrix
{
    ll m[N][N];
    matrix() // 構造函數,初始化
    {
        memset(m,0,sizeof(m));
    };
};
struct trie
{
    int ch[N][M];
    bool cnt[N];
    int fail[N];
    int tot;
    queue<int>q;
    void init()
    {
        tot=0;
        memset(cnt,0,sizeof(cnt));
        memset(fail,0,sizeof(fail));
        memset(ch,0,sizeof(ch));
    }
    void ins(char s[])
    {
        int u=0;
        for(int i=0;s[i];i++)
        {
            int x=s[i]-'a'; // a~z -> 0~25
            if(!ch[u][x])ch[u][x]=++tot;
            u=ch[u][x];
        }
        cnt[u]=1;
    }
    void build_fail()
    {
        for(int i=0;i<M;i++)
        {
            if(ch[0][i])
                q.push(ch[0][i]);
        }
        while(!q.empty())
        {
            int u=q.front();q.pop();
            for(int i=0;i<M;i++)
            {
                int &v=ch[u][i];
                int f=ch[fail[u]][i];
                if(v)
                {
                    fail[v]=f;
                    cnt[v]|=cnt[f];
                    q.push(v);
                }
                else v=f;
            }
        }
    }
    matrix build_matrix() // 得到鄰接矩陣
    {
        matrix ans=matrix();
        for(int i=0;i<=tot;i++)
        {
            if(cnt[i])continue;
            for(int j=0;j<M;j++)
            {
                int v=ch[i][j];
                if(!cnt[v])
                    ans.m[i][v]++;
            }
        }
        return ans;
    }
}ac;
matrix mul(matrix s1,matrix s2,int sz)
{
    matrix ans=matrix();
    for(int i=0;i<sz;i++)
        for(int j=0;j<sz;j++)
            for(int k=0;k<sz;k++)
                ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];// 不用取模!
    return ans;
}
matrix matrix_pow(matrix a,int b,int sz)
{
    matrix ans=matrix();
    for(int i=0;i<sz;i++)
        ans.m[i][i]=1; // 單位矩陣
    while(b)
    {
        if(b&1)ans=mul(ans,a,sz);
        a=mul(a,a,sz);
        b/=2;
    }
    return ans;
}
int n,m;
char t[K];
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n>>m)
    {
        ac.init();
        for(int i=1;i<=n;i++)
        {
            cin>>t;
            ac.ins(t);
        }
        ac.build_fail();
        
        matrix tmp=matrix();
        tmp.m[0][0]=1,tmp.m[0][1]=1,tmp.m[1][1]=26;
        matrix s1=matrix_pow(tmp,m+1,2);
        ll sum1=s1.m[0][1]-1; // 總方案數sum1
        
        matrix a=ac.build_matrix(); // 鄰接矩陣
        int sz=ac.tot+1;
        matrix b=matrix();
        for(int i=0;i<sz;i++)
        {
            b.m[i][i]=1;
            b.m[i][i+sz]=1;
        }
        for(int i=0;i<sz;i++)
            for(int j=0;j<sz;j++)
                b.m[i+sz][j+sz]=a.m[i][j];
        matrix s2=matrix_pow(b,m+1,2*sz);
        ll sum2=0;
        for(int i=sz;i<2*sz;i++)
            sum2+=s2.m[0][i];
        sum2--; //不含模式串的方案數sum2
        
        //printf("sum1=%I64u sum2=%I64u\n",sum1,sum2); // debug
        printf("%I64u\n",sum1-sum2);
    }
    return 0;
}
/*
2 3
aa
ab
sum1=18278 sum2=18174
ans:104

2 13
aa
ab
sum1=2580398988131886038 sum2=2493353857086648626
ans:87045131045237412

2 2000000000
aa
ab
sum1=8116567392432202710 sum2=14915077526685486680
ans:11648233939456267646
*/

POJ 1625 Censored!

這題需要用到高精度,然後因爲冪次比較小可以不用快速冪,直接DP就行,我想着java有大數,那就拿java寫一下大數的矩陣快速冪吧。

注意編碼格式,因爲java沒有unsigned char。將輸入寫成Scanner cin = new Scanner(new BufferedInputStream(System.in), "ISO-8859-1"); 即可。

Java代碼:

//package Main; //package信息一定要去掉,否則RE
import java.io.BufferedInputStream;
import java.math.BigInteger;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Scanner;

class Matrix {
	
	BigInteger m[][];
	Matrix(int sz, int type) { // type控制零矩陣/單位矩陣
		m = new BigInteger[sz][sz];
		for(int i = 0; i < sz; i++)
			for(int j = 0; j < sz; j++)
				m[i][j] = BigInteger.ZERO;
		if(type == 1) {
			for(int i = 0; i < sz; i++)
				m[i][i] = BigInteger.ONE;
		}
	}
	
}

class Trie {
	
	static final int N = 105, M = 55, K = 256;
	int ch[][] = new int[N][M];
    int fail[] = new int[N];
    boolean cnt[] = new boolean[N];
    int tot = 0;
    int len;
    Queue<Integer> q = new LinkedList<Integer>();
    int mp[] = new int[K];
    
    void ins(String s) {
		int u = 0;
		for(int i = 0; i < s.length(); i++)
        {
			int x=mp[s.charAt(i)]; // ASCII碼 -> 0~len-1(len是字母表長度)
            if(ch[u][x] == 0) ch[u][x] = ++tot;
            u = ch[u][x];
        }
        cnt[u] = true;
	}
    
    void build_fail() {
        for(int i = 0; i < len; i++) {
            if(ch[0][i] != 0) 
            	q.offer(ch[0][i]);
        }
        while(!q.isEmpty()) { // 隊列非空
            int u = q.poll(); // 取出並刪除隊頭的元素
            for(int i = 0; i < len; i++) {
                int v = ch[u][i];
                int f = ch[fail[u]][i];
                if(v != 0) {
                	if(cnt[f] == true) cnt[v] = true;
                    fail[v] = f;
                    q.offer(v);
                }
                else ch[u][i] = f; 
            }
        }
    }
    
    Matrix build_matrix() {
    	int sz = tot + 1;
        Matrix ans = new Matrix(sz, 0);
        for(int i = 0; i < sz; i++) {
            if(cnt[i]) continue;
            for(int j = 0;j < len; j++) {
                int v = ch[i][j];
                if(!cnt[v])
                	ans.m[i][v] = ans.m[i][v].add(BigInteger.ONE);
            }
        }
        return ans;
    }
    
    Matrix mul(Matrix s1, Matrix s2) {
    	int sz = tot + 1;
        Matrix ans = new Matrix(sz, 0);
        for(int i = 0; i < sz; i++)
            for(int j = 0; j < sz; j++)
                for(int k = 0; k < sz; k++)
                	ans.m[i][j] = ans.m[i][j].add(s1.m[i][k].multiply(s2.m[k][j]));
        return ans;
    }
    
    Matrix matrix_pow(Matrix a, int b) {
    	int sz = tot + 1;
    	Matrix ans = new Matrix(sz, 1);
        while(b != 0) {
            if(b%2 == 1) ans = mul(ans,a);
            a = mul(a,a);
            b /= 2;
        }
        return ans;
    }
    
}

public class Main {
	
	public static void main(String[] args) {
		Scanner cin = new Scanner(new BufferedInputStream(System.in), "ISO-8859-1");
		//Scanner cin = new Scanner(System.in); 會RE
		Trie ac = new Trie();
		ac.len = cin.nextInt();
		int m = cin.nextInt();
		int n = cin.nextInt();
		String s = cin.next();
		for(int i = 0; i < ac.len; i++) {
			ac.mp[s.charAt(i)] = i;
        }
        for(int i = 1; i <= n; i++) {
        	String t = cin.next(); // 病毒串
            ac.ins(t);
        }
        ac.build_fail();
        Matrix a = ac.build_matrix(); // 鄰接矩陣
        Matrix ans = ac.matrix_pow(a,m);
        BigInteger sum = BigInteger.ZERO;
        int sz = ac.tot + 1;
        for(int i = 0; i < sz; i++) {
        	sum = sum.add(ans.m[0][i]);
        }
        System.out.println(sum);
	}
	
}

C++代碼(沒寫大數,不能AC,只是作爲java代碼的“翻譯”):

#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <cmath>
using namespace std;
typedef unsigned long long ll;
typedef unsigned char uc;
const int N=105,M=55,K=256;
int mp[K];
int len,n,m;
struct matrix
{
    ll m[N][N];
    matrix()
    {
        memset(m,0,sizeof(m));
    }
};
struct trie
{
    int ch[N][M];
    int fail[N];
    bool cnt[N];
    int tot;
    queue<int>q;
    void init()
    {
        memset(cnt,0,sizeof(cnt));
        memset(fail,0,sizeof(fail));
        memset(ch,0,sizeof(ch));
        tot=0;
    }
    void ins(uc s[])
    {
        int u=0;
        for(int i=0;s[i];i++)
        {
            int x=mp[s[i]]; // ASCII碼 -> 0~len-1(len是字母表長度)
            if(!ch[u][x])ch[u][x]=++tot;
            u=ch[u][x];
        }
        cnt[u]=1;
    }
    void build_fail()
    {
        for(int i=0;i<len;i++)
        {
            if(ch[0][i])
                q.push(ch[0][i]);
        }
        while(!q.empty())
        {
            int u=q.front();q.pop();
            for(int i=0;i<len;i++)
            {
                int &v=ch[u][i];
                int f=ch[fail[u]][i];
                if(v)
                {
                    cnt[v]|=cnt[f];
                    fail[v]=f;
                    q.push(v);
                }
                else v=f;
            }
        }
    }
    matrix build_matrix()
    {
        int sz=tot+1;
        matrix ans=matrix();
        for(int i=0;i<sz;i++)
        {
            if(cnt[i])continue;
            for(int j=0;j<len;j++)
            {
                int v=ch[i][j];
                if(!cnt[v])
                    ans.m[i][v]++;
            }
        }
        return ans;
    }
}ac;
matrix mul(matrix s1,matrix s2)
{
    int sz=ac.tot+1;
    matrix ans=matrix();
    for(int i=0;i<sz;i++)
        for(int j=0;j<sz;j++)
            for(int k=0;k<sz;k++)
                ans.m[i][j]+=s1.m[i][k]*s2.m[k][j];
    return ans;
}
matrix matrix_pow(matrix a,int b)
{
    int sz=ac.tot+1;
    matrix ans=matrix();
    for(int i=0;i<sz;i++)
        ans.m[i][i]=1;
    while(b)
    {
        if(b&1)ans=mul(ans,a);
        a=mul(a,a);
        b/=2;
    }
    return ans;
}
int main()
{
    ios::sync_with_stdio(false);
    uc c,t[12];
    cin>>len>>m>>n;
    memset(mp,0,sizeof(mp));
    ac.init();
    for(int i=0;i<len;i++)
    {
        cin>>c;
        mp[c]=i;
    }
    for(int i=1;i<=n;i++)
    {
        cin>>t;
        ac.ins(t);
    }
    ac.build_fail();
    int sz=ac.tot+1;
    matrix a=ac.build_matrix();
    matrix ans=matrix_pow(a,m);
    ll sum=0;
    for(int i=0;i<sz;i++)
        sum+=ans.m[0][i];
    printf("%llu",sum);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章