Gibbs Sampler
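This post presents Gregor Heinrich's classic Java implementation of a collapsed Gibbs sampler for Latent Dirichlet Allocation (LDA). Each word token's topic assignment z_i is resampled in turn from the full conditional p(z_i = k | z_-i, w), which is proportional to (n_-i,k(w_i) + beta) / (n_-i,k(.) + V * beta) * (n_-i,k(d_i) + alpha) / (n_-i,.(d_i) + K * alpha); after a burn-in period, periodic samples of theta and phi are averaged to estimate the document--topic and topic--term distributions.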

/*
 * (C) Copyright 2005, Gregor Heinrich (gregor :: arbylon : net) (This file is
 * part of the org.knowceans experimental software packages.)
 */
/*
 * LdaGibbsSampler is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option) any
 * later version.
 */
/*
 * LdaGibbsSampler is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
 * details.
 */
/*
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
 * Place, Suite 330, Boston, MA 02111-1307 USA
 */

/*
 * Created on Mar 6, 2005
 */
package lda;

import java.text.DecimalFormat;
import java.text.NumberFormat;

/**
 * Gibbs sampler for estimating the best assignments of topics for words and
 * documents in a corpus. The algorithm is introduced in Tom Griffiths' paper
 * "Gibbs sampling in the generative model of Latent Dirichlet Allocation"
 * (2002).
 *
 * @author heinrich
 */
public class LdaGibbsSampler
{

    /**
     * document data (term lists)
     */
    int[][] documents;

    /**
     * vocabulary size
     */
    int V;

    /**
     * number of topics
     */
    int K;

    /**
     * Dirichlet parameter (document--topic associations)
     */
    double alpha = 2.0;

    /**
     * Dirichlet parameter (topic--term associations)
     */
    double beta = 0.5;

    /**
     * topic assignments for each word; z[m][n] := topic of the n-th word in
     * document m
     */
    int z[][];

    /**
     * nw[w][k]: number of instances of term w assigned to topic k
     */
    int[][] nw;

    /**
     * nd[m][k]: number of words in document m assigned to topic k
     */
    int[][] nd;

    /**
     * nwsum[k]: total number of words assigned to topic k
     */
    int[] nwsum;

    /**
     * ndsum[m]: total number of words in document m
     */
    int[] ndsum;

    /**
     * cumulative statistics of theta (summed over samples)
     */
    double[][] thetasum;

    /**
     * cumulative statistics of phi (summed over samples)
     */
    double[][] phisum;

    /**
     * number of samples accumulated in thetasum and phisum
     */
    int numstats;

    /**
     * thinning interval: how often the progress display is updated
     */
    private static int THIN_INTERVAL = 20;

    /**
     * burn-in period: iterations discarded before statistics are collected
     */
    private static int BURN_IN = 100;

    /**
     * max iterations
     */
    private static int ITERATIONS = 1000;

    /**
     * sample lag (if -1, only one sample is taken); averaging parameters over
     * several post-burn-in samples gives higher-quality estimates
     */
    private static int SAMPLE_LAG = 10;

    private static int dispcol = 0;

    /**
     * Initialise the Gibbs sampler with data.
     *
     * @param documents document data (term lists)
     * @param V         vocabulary size
     */
    public LdaGibbsSampler(int[][] documents, int V)
    {

        this.documents = documents;
        this.V = V;
    }

    /**
     * Initialisation: must start with an assignment of observations to topics.
     * Many alternatives are possible; I chose to perform random assignments
     * with equal probabilities.
     *
     * @param K number of topics
     */
    public void initialState(int K)
    {
        int M = documents.length;

        // initialise count variables
        nw = new int[V][K];
        nd = new int[M][K];
        nwsum = new int[K];
        ndsum = new int[M];

        // The z_i are initialised to values in [0, K-1] to determine the
        // initial state of the Markov chain.
        z = new int[M][];
        for (int m = 0; m < M; m++)
        {
            int N = documents[m].length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++)
            {
                int topic = (int) (Math.random() * K);
                z[m][n] = topic;
                // number of instances of word i assigned to topic j
                nw[documents[m][n]][topic]++;
                // number of words in document i assigned to topic j.
                nd[m][topic]++;
                // total number of words assigned to topic j.
                nwsum[topic]++;
            }
            // total number of words in document i
            ndsum[m] = N;
        }
    }

    public void gibbs(int K)
    {
        gibbs(K, 2.0, 0.5);
    }

    /**
     * Main method: select an initial state, then repeat a large number of
     * times: 1. select an element; 2. update it conditional on the other
     * elements. If appropriate, output a summary for each run.
     *
     * @param K     number of topics
     * @param alpha symmetric prior parameter on document--topic associations
     * @param beta  symmetric prior parameter on topic--term associations
     */
    public void gibbs(int K, double alpha, double beta)
    {
        this.K = K;
        this.alpha = alpha;
        this.beta = beta;

        // init sampler statistics (allocate memory)
        if (SAMPLE_LAG > 0)
        {
            thetasum = new double[documents.length][K];
            phisum = new double[K][V];
            numstats = 0;
        }

        // initial state of the Markov chain:
        initialState(K);

        System.out.println("Sampling " + ITERATIONS
                                   + " iterations with burn-in of " + BURN_IN + " (B/S="
                                   + THIN_INTERVAL + ").");

        for (int i = 0; i < ITERATIONS; i++)
        {

            // for all z_i
            for (int m = 0; m < z.length; m++)
            {
                for (int n = 0; n < z[m].length; n++)
                {

                    // (z_i = z[m][n])
                    // sample from p(z_i|z_-i, w)
                    int topic = sampleFullConditional(m, n);
                    z[m][n] = topic;
                }
            }

            if ((i < BURN_IN) && (i % THIN_INTERVAL == 0))
            {
                System.out.print("B");
                dispcol++;
            }
            // display progress
            if ((i > BURN_IN) && (i % THIN_INTERVAL == 0))
            {
                System.out.print("S");
                dispcol++;
            }
            // get statistics after burn-in
            if ((i > BURN_IN) && (SAMPLE_LAG > 0) && (i % SAMPLE_LAG == 0))
            {
                updateParams();
                System.out.print("|");
                if (i % THIN_INTERVAL != 0)
                    dispcol++;
            }
            if (dispcol >= 100)
            {
                System.out.println();
                dispcol = 0;
            }
        }
        System.out.println();
    }

    /**
     * Sample a topic z_i from the full conditional distribution
     * p(z_i = k | z_-i, w), which is proportional to
     * (n_-i,k(w_i) + beta) / (n_-i,k(.) + V * beta)
     * * (n_-i,k(d_i) + alpha) / (n_-i,.(d_i) + K * alpha).
     *
     * @param m document index
     * @param n word index within document m
     * @return the sampled topic
     */
    private int sampleFullConditional(int m, int n)
    {

        // remove z_i from the count variables
        int topic = z[m][n];
        nw[documents[m][n]][topic]--;
        nd[m][topic]--;
        nwsum[topic]--;
        ndsum[m]--;

        // do multinomial sampling via the cumulative method
        double[] p = new double[K];
        for (int k = 0; k < K; k++)
        {
            p[k] = (nw[documents[m][n]][k] + beta) / (nwsum[k] + V * beta)
                    * (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
        }
        // cumulate multinomial parameters (build the CDF)
        for (int k = 1; k < p.length; k++)
        {
            p[k] += p[k - 1];
        }
        // scaled sample because p[] is unnormalised
        double u = Math.random() * p[K - 1];
        for (topic = 0; topic < p.length; topic++)
        {
            if (u < p[topic])
                break;
        }

        // add the newly sampled z_i back to the count variables
        nw[documents[m][n]][topic]++;
        nd[m][topic]++;
        nwsum[topic]++;
        ndsum[m]++;

        return topic;
    }

    /**
     * Add to the statistics the values of theta and phi for the current state.
     */
    private void updateParams()
    {
        for (int m = 0; m < documents.length; m++)
        {
            for (int k = 0; k < K; k++)
            {
                thetasum[m][k] += (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
            }
        }
        for (int k = 0; k < K; k++)
        {
            for (int w = 0; w < V; w++)
            {
                phisum[k][w] += (nw[w][k] + beta) / (nwsum[k] + V * beta);
            }
        }
        numstats++;
    }

    /**
     * Retrieve estimated document--topic associations. If sample lag > 0, the
     * mean of all sampled statistics for theta[][] is taken.
     *
     * @return theta multinomial mixture of document topics (M x K)
     */
    public double[][] getTheta()
    {
        double[][] theta = new double[documents.length][K];

        if (SAMPLE_LAG > 0)
        {
            for (int m = 0; m < documents.length; m++)
            {
                for (int k = 0; k < K; k++)
                {
                    theta[m][k] = thetasum[m][k] / numstats;
                }
            }

        }
        else
        {
            for (int m = 0; m < documents.length; m++)
            {
                for (int k = 0; k < K; k++)
                {
                    theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
                }
            }
        }

        return theta;
    }

    /**
     * Retrieve estimated topic--word associations. If sample lag > 0, the
     * mean of all sampled statistics for phi[][] is taken.
     *
     * @return phi multinomial mixture of topic words (K x V)
     */
    public double[][] getPhi()
    {
        double[][] phi = new double[K][V];
        if (SAMPLE_LAG > 0)
        {
            for (int k = 0; k < K; k++)
            {
                for (int w = 0; w < V; w++)
                {
                    phi[k][w] = phisum[k][w] / numstats;
                }
            }
        }
        else
        {
            for (int k = 0; k < K; k++)
            {
                for (int w = 0; w < V; w++)
                {
                    phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);
                }
            }
        }
        return phi;
    }

    /**
     * Print a text histogram of multinomial data.
     *
     * @param data vector of evidence
     * @param fmax max frequency in display
     */
    public static void hist(double[] data, int fmax)
    {

        double[] hist = new double[data.length];
        // scale maximum
        double hmax = 0;
        for (int i = 0; i < data.length; i++)
        {
            hmax = Math.max(data[i], hmax);
        }
        double shrink = fmax / hmax;
        for (int i = 0; i < data.length; i++)
        {
            hist[i] = shrink * data[i];
        }

        NumberFormat nf = new DecimalFormat("00");
        String scale = "";
        for (int i = 1; i < fmax / 10 + 1; i++)
        {
            scale += "    .    " + i % 10;
        }

        System.out.println("x" + nf.format(hmax / fmax) + "\t0" + scale);
        for (int i = 0; i < hist.length; i++)
        {
            System.out.print(i + "\t|");
            for (int j = 0; j < Math.round(hist[i]); j++)
            {
                if ((j + 1) % 10 == 0)
                    System.out.print("]");
                else
                    System.out.print("|");
            }
            System.out.println();
        }
    }

    /**
     * Configure the Gibbs sampler.
     *
     * @param iterations   number of total iterations
     * @param burnIn       number of burn-in iterations
     * @param thinInterval progress-display update interval
     * @param sampleLag    sample interval (-1 for just one sample at the end)
     */
    public void configure(int iterations, int burnIn, int thinInterval,
                          int sampleLag)
    {
        ITERATIONS = iterations;
        BURN_IN = burnIn;
        THIN_INTERVAL = thinInterval;
        SAMPLE_LAG = sampleLag;
    }

    /**
     * Infer the topic distribution of a new document under a pre-trained phi
     * matrix. Phi is held fixed while the new document's topic assignments
     * are resampled.
     *
     * @param alpha symmetric document--topic prior
     * @param beta  symmetric topic--term prior (phi from getPhi() is already
     *              smoothed with beta, so beta is not used again here)
     * @param phi   pre-trained topic--term matrix (K x V)
     * @param doc   the new document, as an array of term ids
     * @return theta, the document's topic distribution (length K)
     */
    public static double[] inference(double alpha, double beta, double[][] phi, int[] doc)
    {
        int K = phi.length;

        // initialise count variables: only document--topic counts are needed,
        // because the topic--term distribution phi is held fixed
        int[] nd = new int[K];
        int ndsum = 0;

        // The z_i are initialised to values in [0, K-1] to determine the
        // initial state of the Markov chain.
        int N = doc.length;
        int[] z = new int[N];
        for (int n = 0; n < N; n++)
        {
            int topic = (int) (Math.random() * K);
            z[n] = topic;
            // number of words in the document assigned to topic k
            nd[topic]++;
        }
        // total number of words in the document
        ndsum = N;
        for (int i = 0; i < ITERATIONS; i++)
        {
            for (int n = 0; n < z.length; n++)
            {

                // sample z_i from p(z_i | z_-i, w):
                // first remove z_i from the count variables
                int topic = z[n];
                nd[topic]--;
                ndsum--;

                // do multinomial sampling via the cumulative method; the word
                // likelihood comes from the fixed, pre-trained phi
                double[] p = new double[K];
                for (int k = 0; k < K; k++)
                {
                    p[k] = phi[k][doc[n]]
                            * (nd[k] + alpha) / (ndsum + K * alpha);
                }
                // cumulate multinomial parameters (build the CDF)
                for (int k = 1; k < p.length; k++)
                {
                    p[k] += p[k - 1];
                }
                // scaled sample because p[] is unnormalised
                double u = Math.random() * p[K - 1];
                for (topic = 0; topic < p.length; topic++)
                {
                    if (u < p[topic])
                        break;
                }

                // add the newly sampled z_i back to the count variables
                nd[topic]++;
                ndsum++;
                z[n] = topic;
            }
        }

        double[] theta = new double[K];

        for (int k = 0; k < K; k++)
        {
            theta[k] = (nd[k] + alpha) / (ndsum + K * alpha);
        }
        return theta;
    }

    public static double[] inference(double[][] phi, int[] doc)
    {
        return inference(2.0, 0.5, phi, doc);
    }

    /**
     * Driver with example data.
     */
    public static void main(String[] args)
    {

        // words in documents, as term ids
        int[][] documents = {
                {1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 6},
                {2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2},
                {1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 0},
                {5, 6, 6, 2, 3, 3, 6, 5, 6, 2, 2, 6, 5, 6, 6, 6, 0},
                {2, 2, 4, 4, 4, 4, 1, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 0},
                {5, 4, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2}};
        // vocabulary size
        int V = 7;
        // number of topics
        int K = 2;
        // good values alpha = 2, beta = .5
        double alpha = 2;
        double beta = .5;

        System.out.println("Latent Dirichlet Allocation using Gibbs Sampling.");

        LdaGibbsSampler lda = new LdaGibbsSampler(documents, V);
        lda.configure(10000, 2000, 100, 10);
        lda.gibbs(K, alpha, beta);

        double[][] theta = lda.getTheta();
        double[][] phi = lda.getPhi();

        System.out.println();
        System.out.println();
        System.out.println("Document--Topic Associations, Theta[d][k] (alpha="
                                   + alpha + ")");
        System.out.print("d\\k\t");
        for (int k = 0; k < theta[0].length; k++)
        {
            System.out.print("   " + k % 10 + "    ");
        }
        System.out.println();
        for (int m = 0; m < theta.length; m++)
        {
            System.out.print(m + "\t");
            for (int k = 0; k < theta[m].length; k++)
            {
                System.out.print(shadeDouble(theta[m][k], 1) + " ");
            }
            System.out.println();
        }
        System.out.println();
        System.out.println("Topic--Term Associations, Phi[k][w] (beta=" + beta
                                   + ")");

        System.out.print("k\\w\t");
        for (int w = 0; w < phi[0].length; w++)
        {
            System.out.print("   " + w % 10 + "    ");
        }
        System.out.println();
        for (int k = 0; k < phi.length; k++)
        {
            System.out.print(k + "\t");
            for (int w = 0; w < phi[k].length; w++)
            {
                System.out.print(shadeDouble(phi[k][w], 1) + " ");
            }
            System.out.println();
        }
        // infer the topic distribution of a new document
        int[] aNewDocument = {2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2};
        double[] newTheta = inference(alpha, beta, phi, aNewDocument);
        for (int k = 0; k < newTheta.length; k++)
        {
            System.out.print(shadeDouble(newTheta[k], 1) + " ");
        }
        System.out.println();
    }

    static String[] shades = {"     ", ".    ", ":    ", ":.   ", "::   ",
            "::.  ", ":::  ", ":::. ", ":::: ", "::::.", ":::::"};

    static NumberFormat lnf = new DecimalFormat("00E0");

    /**
     * Create a string representation whose gray value appears as an indicator
     * of magnitude, cf. Hinton diagrams in statistics.
     *
     * @param d   value
     * @param max maximum value
     * @return a bracketed shade string whose fill indicates d relative to max
     */
    public static String shadeDouble(double d, double max)
    {
        int a = (int) Math.floor(d * 10 / max + 0.5);
        if (a > 10 || a < 0)
        {
            String x = lnf.format(d);
            a = 5 - x.length();
            for (int i = 0; i < a; i++)
            {
                x += " ";
            }
            return "<" + x + ">";
        }
        return "[" + shades[a] + "]";
    }
}
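
The bundled main() drives the sampler with hand-coded term ids. For real text, tokens must first be mapped to contiguous integer ids in [0, V-1]. Below is a minimal sketch of that plumbing; the toy corpus, the LdaDemo class name, and the vocabulary map are illustrative assumptions, not part of the original code.

package lda;

import java.util.HashMap;
import java.util.Map;

public class LdaDemo
{
    public static void main(String[] args)
    {
        // hypothetical tokenised corpus; any tokeniser's output would do
        String[][] tokenised = {
                {"cat", "dog", "cat", "fish"},
                {"dog", "dog", "bone"},
                {"fish", "cat", "fish", "fish"}};

        // assign each distinct term a stable integer id in [0, V-1]
        Map<String, Integer> vocabulary = new HashMap<>();
        int[][] documents = new int[tokenised.length][];
        for (int m = 0; m < tokenised.length; m++)
        {
            documents[m] = new int[tokenised[m].length];
            for (int n = 0; n < tokenised[m].length; n++)
            {
                Integer id = vocabulary.get(tokenised[m][n]);
                if (id == null)
                {
                    id = vocabulary.size();
                    vocabulary.put(tokenised[m][n], id);
                }
                documents[m][n] = id;
            }
        }

        // train: iterations, burn-in, display interval, sample lag
        LdaGibbsSampler lda = new LdaGibbsSampler(documents, vocabulary.size());
        lda.configure(10000, 2000, 100, 10);
        lda.gibbs(2, 2.0, 0.5);    // K = 2 topics, alpha = 2.0, beta = 0.5

        double[][] theta = lda.getTheta();    // M x K document--topic mixtures
        double[][] phi = lda.getPhi();        // K x V topic--term mixtures
        System.out.println("documents: " + theta.length
                + ", topics: " + phi.length
                + ", vocabulary: " + phi[0].length);
    }
}

One design caveat: the sampler draws from Math.random(), which cannot be seeded, so runs are not reproducible. If repeatability matters, replacing those calls with a shared java.util.Random instance constructed from a fixed seed is a straightforward modification.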
