poj 1625

poj 1625


題意:
現有一個奇葩的國家。總統不允許任何含有非法單詞的句子出現。假定這個國家的句子長度固定爲m。現給定大小爲n 的字符集,和p個長度爲min(m,10)的非法單詞。任何包含非法單詞的句子都是非法的。求所有合法句子的數目。
爲了解決這個問題
我們先假定字符集{‘A’,’G’,’C’,’T’}
非法單詞有2個,分別是”ACG”, “C”,句子長度爲m
我們將非法串建立trie,得到如下的圖


這張trie圖和普通的trie樹不同之處在於,它補全了所有trie樹中原本不存在的邊。每一個狀態節點的某一個子節點若不存在,則將其指向這個狀態節點的fail節點所對應的那個相同的子節點。即 若 ch[r][c]==0,ch[r][c]=ch[f[r]][c](大白書的寫法) 這樣一來原來的trie樹變成了一張有向的狀態轉移圖。從trie圖中的根節點開始,沿着有向邊走m步,即可得到一個長度爲m的字串。
我們知道的是,自動機上的字串匹配其實就是各個狀態間的轉移,後綴與前綴之間的互相匹配的過程。不包含非法單詞其實可以理解爲,在走m步的過程中沒有經過任何trie樹上的單詞結尾狀態節點,也就是上面trie圖中所標出的紅色節點。因爲無論trie圖中的狀態從何轉移而來,其已經匹配過的字串必定包含從根節點到當前狀態節點的路徑所代表的非法單詞。而包含非法單詞的句子是不被允許的,因此在沿着有向邊轉移m次的過程中,不能轉向任何非法單詞結尾狀態節點。同理,也不能從任何的非法節點轉出。
現在問題變成了,從根節點開始轉移m次至合法節點,共有多少種不同的方案。設計二維狀態dp[i][j]表示第i步走到第j個節點的方案總數。
轉移方程:dp[i][j] += dp[i-1][k](k到j有一條有向邊且k,j均不爲非法節點)
dp邊界狀態爲:dp[0][0]=1; dp[0][k] = 0;(0<k<sz,sz爲所有狀態節點總數目)


需要注意的有兩點:
1. 可能某一個非法單詞包含了另一個非法單詞。如上圖中’ACG’和單詞’C’,在這種情況下,2狀態節點也需要被標記爲非法的狀態。因爲匹配了’ACG’就一定匹配了’C’。在實際過程中只需要判下last[j]==root
2. 題中沒有提到取模,因此要使用大數相加。


下面帖代碼:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
#define maxnode 105
#define sigma_size 52
using namespace std;

const int base = 10;

int n, m, p;

int ch[maxnode][sigma_size];
int val[maxnode];
int sz;

int f[maxnode];
int last[maxnode];

char charset[maxnode];
char ban[12];

int get(char a)
{
    for(int i = 0; i < n; i++)
        if(charset[i] == a)
            return i;
    return -1;
}
void initial()
{
    memset(ch[0], 0, sizeof(ch[0]));
    sz = 1;
}
void insert(char *s)
{
    int l = strlen(s);
    int u = 0;
    for(int i = 0; i < l; i++)
    {
        int c = get(s[i]);
        if(!ch[u][c])
        {
            memset(ch[sz], 0, sizeof(ch[sz]));
            val[sz] = 0;
            ch[u][c] = sz++;
        }
        u = ch[u][c];
    }
    val[u] = 1;
}
void getfail()
{
    queue<int >q;
    f[0] = 0;
    for(int c = 0; c < n; c++)
    {
        int u = ch[0][c];
        if(u)
        {
            f[u] = 0;
            q.push(u);
            last[u] = 0;
        }
    }
    while(!q.empty())
    {
        int r = q.front();
        q.pop();
        for(int c = 0; c < n; c++)
        {
            int u = ch[r][c];
            if(!u)
            {
                ch[r][c] = ch[f[r]][c];
                continue;
            }
            q.push(u);
            int v = f[r];
            while(v && !ch[v][c])
                v = f[v];
            f[u] = ch[v][c];
            last[u] = val[f[u]] ? f[u] : last[f[u]];
            // if(last[u])
            //   val[u] = 1;
        }
    }
}
struct BigInt
{
    int v[maxnode], len;
    BigInt(int r = 0)
    {
        memset(v, 0, sizeof(v));
        for(len = 0; r > 0; r /= base)v[len++] = r % base;
    }
    BigInt operator + (const BigInt &a)
    {
        BigInt ans;
        int i , c = 0;
        for(i = 0; i < len || i < a.len || c > 0; i++)
        {
            if(i < len)c += v[i];
            if(i < a.len)c += a.v[i];
            ans.v[i] = c % base;
            c /= base;
        }
        ans.len = i;
        return ans;
    }
    void print()
    {
        printf("%d", len == 0 ? 0 : v[len - 1]);
        for(int i = len - 2; i >= 0; i--)
            printf("%d", v[i]);
        printf("\n");
    }
};
BigInt dp[52][maxnode];
int check(int k, int j)
{
    int ans = 0;
    for(int i = 0; i < n; i++)
        if(ch[k][i] == j)
            ans++;
    return ans;
}
int main()
{
    initial();
    scanf("%d%d%d", &n, &m, &p);
    scanf("%s", charset);
    while(p--)
    {
        scanf("%s", ban);
        insert(ban);
    }
    getfail();
    //初始化邊界條件
    for(int i = 0; i <= m; i++)
        for(int j = 0; j < sz; j++)
        {
            dp[i][j] = BigInt();
        }
    dp[0][0] = BigInt(1);

    //dp過程
    for(int i = 1 ; i <= m; i++)
    {
        for(int j = 0; j < sz; j++)
        {
            if(val[j] || val[last[j]])
                continue;
            for(int k = 0; k < sz; k++)
            {
                if(val[last[k]] || val[k])
                    continue;
                int ti=check(k, j);
                for(int c=0;c<ti;c++)
                    dp[i][j] = dp[i][j] + dp[i - 1][k];
            }

        }
    }
    BigInt ans = BigInt();
    for(int i = 0; i < sz; i++)
    {
        ans = ans + dp[m][i];
    }
    ans.print();
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章