語音識別—Viterbi解碼

                                                                             Viterbi解碼理論與實戰

  筆者最近着手研究基於HMM的語音識別系統,之前博文基於C++工具手寫了提取MFCC語音特徵(具體可以觀看之前博文),最同時,也對基於GMM-HMM的語音識別訓練過程進行了理論推導,現本文對基於Viterbi的解碼方法進行詳細的研究。

  曾看過很多語音識別書(餘棟的語音識別實戰、陳果 果的kaldi實戰以及張雪英的數位語音處理等書)、課程(七月在線、深藍學院等課程),上述研究對於Viterbi應用於語音識別僅停留於理論階段,筆者針對他們的理論應用於語音識別有保留的贊同,以下則是本人對Viterbi算法的理論理解與工程實現方法,持有不同意見的讀者不妨對本人提出建議:

一、Viterbi解碼理論部分

1.1 餘棟書對Viterbi算法的理論介紹

  首先看上書對Viterbi解碼提供的僞代碼部分:

  對上述僞代碼筆者曾一頭霧水,對其中很多念理解不清楚,實際上其具體含義可以理解爲:從狀態初始至結尾找出概率最大的路徑,該路徑可以通過回溯找出對應狀態序列,最後進行反轉得到正序語音序列。

  對於此僞代碼,我相信如果不經歷過自己手寫過Viterbi算法,很難理解其中具體的實現過程,現筆者開始對其進行深入剖析。

1.2 個人見解

  值得指出的是,筆者很好在相關書上找出究竟是怎樣對Viterbi概率進行準確的求解,即很少有介紹狀態之間概率究竟是如果進行跳轉以及計算的。值得指出的是:狀態之間的計算僅是當前狀態與之前一幀部分狀態之間的計算,因爲累積會造成概率小時或者概率爆炸的問題,故對其進行log計算,故累積變爲累加,其可解決上述問題。

 上述中何爲“當前狀態與之前一幀部分狀態之間”,其實這句話對於理解Viterbi計算有很重要的作用,部分即爲之前一幀僅有部分狀態與當前狀態有關係(即爲有跳轉)。必須說明的是,最後的每個狀態的累加結果包含三部分:弧上概率(爲防止其他概率不存在,一般會對弧上概率進行初始化,可以爲負無窮或正無窮)、gmm概率(即爲特徵與高斯混合模型得到的每幀中每個狀態對應的概率密度函數,pdf)以及狀態轉移概率(狀態轉移即對應與當前狀態想關的之前狀態)。

  因爲Viterbi算法涉及回溯找出概率最大對應的狀態序列(實際上說是狀態序列實際上是不太準確的,因爲在解碼過程中,每個HMM可能有N個狀態,因此實際解碼過程中使用的是弧序號來替代狀態號,當然二者之間存在互相映射關係),因此對於回溯過程中容器的選擇至關重要,其主要涉及弧號以及概率的保存(可以設置自定義變量保存二者),然後使用矩陣將自定義變量按順序保存至矩陣容器中,因此後文實現過程中使用Matrix<VitCell>形式保存自定義變量。

二、Viterbi解碼實戰部分

2.1 特徵數據

  首先本文參考是的哥倫比亞大學的語音識別課程進行Viterbi解碼,其首先針對一句語料進行解碼,最後提取該句語料特徵:68*12,其中,68表示幀數,12表示MFCC特徵。

2.2 symbol 序列

  音素序列即爲建模單元,針對不同語音識別系統,可能有不同的建模單元,常見的有:狀態、音素、音節、字等,筆者使用的symbol詞典格式如下所示:

  音素與其序列之間爲互相映射關係,後期可以直接通過音素序號得到對應的音素,而音素序號可以通過訓練得到模型得到狀態對應的音素,音素序號與狀態也爲互相映射關係,但正如上文所述,實際應用過程中使用弧序號進行存儲矩陣容器,然而弧號可與狀態號進行互相映射。

  很繞腦,但由上文可得:可以通過弧號得到對應的音素序列

2.3 graph 序列

  由下表第四行可知,每行以此保存着初始狀態、跳轉狀態、對應GMM序號以及對應的音素。

  值得說明的是:graph表中最後保存的是HMM狀態鏈的終止狀態,終止狀態不發生狀態跳轉,其結果如下:

  由上圖可知,本文使用的graph序列終止狀態爲47與84兩個不同的狀態,後期進行Viterbi計算得到最大的終止狀態。

2.4 gmm 訓練結果

  通過EM算法不斷迭代存儲GMM模型,以此包含:GMM對應的狀態序號、各GMM權重、以及均值與方差,,其中均值與方差存儲結果如下圖所示:

  因爲語音識別常使用對角GMM對特徵進行計算,其中每行有24個元素(均值與方差維度必須與特徵維度一直,否則無法計算每幀屬於某個狀態的pdf),奇數列、偶數列分別表示方差與均值,因爲是對角線元素,故方差值必爲正值,讀者可以參考哥大的代碼,建立graph類,以下爲模型建立核心過程:

string Graph::read(istream& inStrm, const string& name) {
    clear();
    string retStr;
    string lineStr;
    vector<string> fieldList;
    while (true) {
        int peekChar = inStrm.peek();
        if (peekChar != '#')
            break;
        getline(inStrm, lineStr);
        split_string(lineStr, fieldList);
        if ((fieldList.size() == 3) && (fieldList[0] == "#") &&
            (fieldList[1] == "name:")) {
            if (!name.empty() && (fieldList[2] != name))
                throw runtime_error(str(format("Unexpected FSM name: %s/%s") %
                    name % fieldList[2]));
            if (!retStr.empty())
                throw runtime_error(str(format("FSM has two names: %s/%s") %
                    retStr % fieldList[2]));
            retStr = fieldList[2];
        }
    }
    int lastIdx = -1;
    vector<pair<int, Arc>> arcList;
    double logFactor = -log(10.0);
    while (true) {
        int peekChar = inStrm.peek();
        if ((peekChar == '#') || (peekChar == EOF))
            break;
        getline(inStrm, lineStr);
        split_string(lineStr, fieldList);
        if (!fieldList.size())
            continue;
        try {
            int srcIdx = lexical_cast<int>(fieldList[0]);
            if (srcIdx < 0)
                throw runtime_error("Negative state index in FSM: " + lineStr);
            if (m_start == -1)
                m_start = srcIdx;
            if (srcIdx > lastIdx)
                lastIdx = srcIdx;
            if (fieldList.size() <= 2) {
                double logProb = (fieldList.size() > 1) ?
                    lexical_cast<double>(fieldList[1]) * logFactor : 0.0;
                if (m_finalLogProbs.find(srcIdx) != m_finalLogProbs.end())
                    throw runtime_error("Dup final state in FSM: " + lineStr);
                m_finalLogProbs[srcIdx] = logProb;
                continue;
            }
            
            if ((fieldList.size() == 3) || (fieldList.size() > 5))
                throw runtime_error("Invalid num fields in FSM: " + lineStr);
            unsigned dstIdx = lexical_cast<int>(fieldList[1]);
            if (dstIdx < 0)
                throw runtime_error("Negative state index in FSM: " + lineStr);
            if (dstIdx > lastIdx)
                lastIdx = dstIdx;
            int gmmIdx = -1;
            const string& gmmStr = fieldList[2];
            if ((gmmStr.length() >= 3) && (gmmStr.length() <= 9) &&
                (gmmStr[0] == '<') && (gmmStr[gmmStr.length() - 1] == '>') &&
                (string("epsilon").substr(0, gmmStr.length() - 2) ==
                    gmmStr.substr(1, gmmStr.length() - 2))) {
                ;
            }
            else {
                gmmIdx = lexical_cast<int>(gmmStr);
                if (gmmIdx < 0)
                    throw runtime_error("Negative GMM index in FSM: " +
                        lineStr);
                int wordIdx = !m_symTable->empty() ?
                    m_symTable->get_index(fieldList[3]) : 0;
                if (wordIdx < 0)
                    throw runtime_error("OOV word in FSM: " + lineStr);
                double logProb = (fieldList.size() > 4) ?
                    lexical_cast<double>(fieldList[4]) * logFactor : 0.0;
                Arc arc(dstIdx, gmmIdx, wordIdx, logProb);
                arcList.push_back(make_pair(srcIdx, arc));
            }
        }
        catch (bad_lexical_cast&)
        {
            throw runtime_error("Invalid type for field in FSM: " + lineStr);
        }
    }
    if (m_start < 0)
        throw runtime_error("Empty FSM.");
    //lastIdx:122;
    int stateCnt = lastIdx + 1;
    m_stateMap.reserve(stateCnt);
    m_arcList.reserve(arcList.size());
    sort(arcList.begin(), arcList.end(), CompareArcs());
    for (int arcIdx = 0; arcIdx < (int)arcList.size(); ++arcIdx) {
        m_arcList.push_back(arcList[arcIdx].second);
        //arcList[arcIdx].second.pringResults();
        int srcIdx = arcList[arcIdx].first;
        while ((int)m_stateMap.size() <= srcIdx)
            m_stateMap.push_back(arcIdx);
    }
    //printVector<>(m_stateMap);
    while ((int)m_stateMap.size() < stateCnt)
        m_stateMap.push_back(arcList.size());
    assert(((int)m_stateMap.size() == stateCnt) &&
        (m_arcList.size() == arcList.size()));
    for (int stateIdx = 0; stateIdx < stateCnt; ++stateIdx) {
        int minArcIdx = get_min_arc_index(stateIdx);
        int maxArcIdx = get_max_arc_index(stateIdx);
        for (int arcIdx = minArcIdx; arcIdx < maxArcIdx; ++arcIdx)
            assert(arcList[arcIdx].first == stateIdx);
    }
    return retStr;
}

  上述代碼主要將Graph表中元素存儲到同的容器中,讀者可以根據需求自己建立模型。

2.5 chart 容器

  chart是Viterbi解碼的核心解碼圖,其可以理解爲格子圖,其維度爲:69*123,因爲解碼需要初始位置,故將68幀語音特徵之前建立一幀作爲起始位置,其中起始值爲(-1, 0),其中-1表示弧號,0表示log似然值,其具體建立過程如下所示:

bool Lab2VitMain::init_utt() {
    if (m_audioStrm.peek() == EOF) {
        return false;
    }
    m_idStr = read_float_matrix(m_audioStrm, m_inAudio);
    cout << "Processing utterance ID: " << m_idStr << endl;
    m_frontEnd.get_feats(m_inAudio, m_feats);
    if (m_feats.size2() != m_gmmSet.get_dim_count())
        throw runtime_error("Mismatch in GMM and feat dim.");
    if (m_doAlign) {
        if (m_graphStrm.peek() == EOF)
            throw runtime_error(
                "Mismatch in number of audio files "
                "and FSM's.");
        m_graph.read(m_graphStrm, m_idStr);
    }
    if (m_graph.get_gmm_count() > m_gmmSet.get_gmm_count())
        throw runtime_error(
            "Mismatch in number of GMM's between "
            "FSM and GmmSet.");
    //m_gmmProbs矩陣維度爲68*102,即爲當前幀屬於某個狀態的pdf;
    m_gmmSet.calc_gmm_probs(m_feats, m_gmmProbs);
    m_chart.resize(m_feats.size1() + 1, m_graph.get_state_count());
    m_chart.clear();
    if (m_graph.get_start_state() < 0)
        throw runtime_error("Graph has no start state.");
    return true;
}

  chart矩陣爲什麼對於語音識別解碼至關重要,該矩陣每個元素保存的是弧號與至此狀態最大的似然概率值,如果對chart容器有問題,可以加入微信解碼羣研究。

2.6 回溯 

  實際上chart圖最後一幀即爲最大似然概率對應的弧號,可以基於此回溯得到完成的弧號序列,弧號與狀態之間存在映射關係,實際上每個弧號對應的狀態的起始狀態不就是前一幀的弧號對應的終止狀態嗎(這裏大家可以仔細理解下)?,基於此可以以此得到狀態最大似然概率對應的弧序列,進而映射爲狀態序列和音素序列,其模型具體建立過程如下所示:

double viterbi_backtrace(const Graph& graph, matrix<VitCell>& chart,
    vector<int>& outLabelList, bool doAlign) {
    int frmCnt = chart.size1() - 1;
    int stateCnt = chart.size2();
    //finalStates存儲終止狀態對應的序號,且對其進行排序;
    vector<int> finalStates;
    int finalCnt = graph.get_final_state_list(finalStates);   
    double bestLogProb = g_zeroLogProb;
    int bestFinalState = -1;
    for (int finalIdx = 0; finalIdx < finalCnt; ++finalIdx) {
        int stateIdx = finalStates[finalIdx];
        if (chart(frmCnt, stateIdx).get_log_prob() == g_zeroLogProb) continue;
        //curLogProb表示終止狀態對應的似然值與弧上概率的累加值;
        //加上弧上概率是因爲終止狀態再進行log似然值累加時,終止狀態上並未添加弧上概率;
        double curLogProb = chart(frmCnt, stateIdx).get_log_prob() +
            graph.get_final_log_prob(stateIdx);
        if (curLogProb > bestLogProb)
            bestLogProb = curLogProb, bestFinalState = stateIdx;
    }
    if (bestFinalState < 0) throw runtime_error("No complete paths found.");
    outLabelList.clear();
    int stateIdx = bestFinalState;
    for (int frmIdx = frmCnt; --frmIdx >= 0;) {
        assert((stateIdx >= 0) && (stateIdx < stateCnt));
        int arcId = chart(frmIdx + 1, stateIdx).get_arc_id();
        Arc arc;
        graph.get_arc(arcId, arc);
        assert((int)arc.get_dst_state() == stateIdx);
        if (doAlign) {
            throw runtime_error("Expect all arcs to have GMM.");
            outLabelList.push_back(arc.get_gmm());
        }
        else if (arc.get_word() > 0) {
            outLabelList.push_back(arc.get_word());
        }
        stateIdx = graph.get_src_state(arcId);
        cout << stateIdx << endl;
    }
    if (stateIdx != graph.get_start_state())
        throw runtime_error("Backtrace does not end at start state.");
    reverse(outLabelList.begin(), outLabelList.end());
    return bestLogProb;
}

至此Viterbi算法理論與實踐簡要介紹完畢

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章