/*
* (C) Copyright 2005, Gregor Heinrich (gregor :: arbylon : net) (This file is
* part of the org.knowceans experimental software packages.)
*/
/*
* LdaGibbsSampler is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 2 of the License, or (at your option) any
* later version.
*/
/*
* LdaGibbsSampler is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*/
/*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 59 Temple
* Place, Suite 330, Boston, MA 02111-1307 USA
*/
/*
* Created on Mar 6, 2005
*/
package lda;
import java.text.DecimalFormat;
import java.text.NumberFormat;
/**
* Gibbs sampler for estimating the best assignments of topics for words and
* documents in a corpus. The algorithm is introduced in Tom Griffiths' paper
* "Gibbs sampling in the generative model of Latent Dirichlet Allocation"
* (2002).<br>
* Gibbs sampler採樣算法的實現
*
* @author heinrich
*/
/**
 * Gibbs sampler for estimating the best assignments of topics for words and
 * documents in a corpus. The algorithm is introduced in Tom Griffiths' paper
 * "Gibbs sampling in the generative model of Latent Dirichlet Allocation"
 * (2002).
 *
 * <p>Usage: construct with the corpus, optionally {@link #configure}, call
 * {@link #gibbs(int, double, double)}, then read {@link #getTheta()} /
 * {@link #getPhi()}. NOTE: the iteration settings and the progress-display
 * column are {@code static}, i.e. shared by all instances.</p>
 *
 * @author heinrich
 */
public class LdaGibbsSampler
{
    /** Document data: documents[m][n] is the term id of the n-th word of document m. */
    int[][] documents;

    /** Vocabulary size (term ids are 0..V-1). */
    int V;

    /** Number of topics. */
    int K;

    /** Dirichlet prior on document--topic associations (symmetric). */
    double alpha = 2.0;

    /** Dirichlet prior on topic--term associations (symmetric). */
    double beta = 0.5;

    /** Topic assignments: z[m][n] is the topic of the n-th word of document m. */
    int z[][];

    /** Counter: nw[i][j] = number of instances of term i assigned to topic j. */
    int[][] nw;

    /** Counter: nd[m][j] = number of words in document m assigned to topic j. */
    int[][] nd;

    /** Counter: nwsum[j] = total number of words assigned to topic j. */
    int[] nwsum;

    /** Counter: ndsum[m] = total number of words in document m. */
    int[] ndsum;

    /** Cumulative theta statistics, averaged over post-burn-in samples. */
    double[][] thetasum;

    /** Cumulative phi statistics, averaged over post-burn-in samples. */
    double[][] phisum;

    /** Number of samples accumulated into thetasum/phisum. */
    int numstats;

    /** Progress-display interval (iterations between "B"/"S" marks). */
    private static int THIN_INTERVAL = 20;

    /** Burn-in period: iterations discarded before statistics are collected. */
    private static int BURN_IN = 100;

    /** Total number of Gibbs iterations. */
    private static int ITERATIONS = 1000;

    /**
     * Sample lag: every SAMPLE_LAG-th post-burn-in state is folded into the
     * parameter averages (if &lt;= 0, only the final state is used).
     */
    private static int SAMPLE_LAG = 10;

    /** Current column of the console progress display. */
    private static int dispcol = 0;

    /**
     * Initialise the Gibbs sampler with data.
     *
     * @param documents term lists, one int[] of term ids per document
     * @param V vocabulary size
     */
    public LdaGibbsSampler(int[][] documents, int V)
    {
        this.documents = documents;
        this.V = V;
    }

    /**
     * Initialisation: the Markov chain must start from some assignment of
     * observations to topics; here each word gets a uniformly random topic.
     *
     * @param K number of topics
     */
    public void initialState(int K)
    {
        int M = documents.length;

        // initialise count variables
        nw = new int[V][K];
        nd = new int[M][K];
        nwsum = new int[K];
        ndsum = new int[M];

        // The z_i are initialised to values in [0,K) to determine the initial
        // state of the Markov chain.
        z = new int[M][];
        for (int m = 0; m < M; m++)
        {
            int N = documents[m].length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++)
            {
                int topic = (int) (Math.random() * K);
                z[m][n] = topic;
                // number of instances of term w assigned to this topic
                nw[documents[m][n]][topic]++;
                // number of words in document m assigned to this topic
                nd[m][topic]++;
                // total number of words assigned to this topic
                nwsum[topic]++;
            }
            // total number of words in document m
            ndsum[m] = N;
        }
    }

    /**
     * Run the sampler with the default priors alpha=2.0, beta=0.5.
     *
     * @param K number of topics
     */
    public void gibbs(int K)
    {
        gibbs(K, 2.0, 0.5);
    }

    /**
     * Main method: select an initial state, then repeatedly resample every
     * z_i conditioned on all the others, collecting statistics after burn-in.
     * Progress is written to stdout ("B" = burn-in, "S" = sampling, "|" =
     * parameter sample taken).
     *
     * @param K number of topics
     * @param alpha symmetric prior parameter on document--topic associations
     * @param beta symmetric prior parameter on topic--term associations
     */
    public void gibbs(int K, double alpha, double beta)
    {
        this.K = K;
        this.alpha = alpha;
        this.beta = beta;

        // init sampler statistics
        if (SAMPLE_LAG > 0)
        {
            thetasum = new double[documents.length][K];
            phisum = new double[K][V];
            numstats = 0;
        }

        // initial state of the Markov chain:
        initialState(K);
        System.out.println("Sampling " + ITERATIONS
                + " iterations with burn-in of " + BURN_IN + " (B/S="
                + THIN_INTERVAL + ").");
        for (int i = 0; i < ITERATIONS; i++)
        {
            // for all z_i: sample from p(z_i|z_-i, w)
            for (int m = 0; m < z.length; m++)
            {
                for (int n = 0; n < z[m].length; n++)
                {
                    z[m][n] = sampleFullConditional(m, n);
                }
            }

            if ((i < BURN_IN) && (i % THIN_INTERVAL == 0))
            {
                System.out.print("B");
                dispcol++;
            }
            // display progress
            if ((i > BURN_IN) && (i % THIN_INTERVAL == 0))
            {
                System.out.print("S");
                dispcol++;
            }
            // get statistics after burn-in
            if ((i > BURN_IN) && (SAMPLE_LAG > 0) && (i % SAMPLE_LAG == 0))
            {
                updateParams();
                System.out.print("|");
                if (i % THIN_INTERVAL != 0)
                    dispcol++;
            }
            if (dispcol >= 100)
            {
                System.out.println();
                dispcol = 0;
            }
        }
        System.out.println();
    }

    /**
     * Sample a topic z_i from the full conditional distribution:
     * p(z_i = j | z_-i, w) = (n_-i,j(w_i) + beta)/(n_-i,j(.) + W * beta)
     *                      * (n_-i,j(d_i) + alpha)/(n_-i,.(d_i) + K * alpha)
     *
     * @param m document index
     * @param n word index within document m
     * @return the newly sampled topic for word n of document m
     */
    private int sampleFullConditional(int m, int n)
    {
        // remove z_i from the count variables
        int topic = z[m][n];
        nw[documents[m][n]][topic]--;
        nd[m][topic]--;
        nwsum[topic]--;
        ndsum[m]--;

        // do multinomial sampling via the cumulative method:
        double[] p = new double[K];
        for (int k = 0; k < K; k++)
        {
            p[k] = (nw[documents[m][n]][k] + beta) / (nwsum[k] + V * beta)
                    * (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
        }
        // cumulate multinomial parameters
        for (int k = 1; k < p.length; k++)
        {
            p[k] += p[k - 1];
        }
        // scaled sample because of unnormalised p[]
        double u = Math.random() * p[K - 1];
        for (topic = 0; topic < p.length; topic++)
        {
            if (u < p[topic])
                break;
        }
        // guard against floating-point round-off pushing the scan past the end
        if (topic == p.length)
            topic = p.length - 1;

        // add the newly estimated z_i back to the count variables
        nw[documents[m][n]][topic]++;
        nd[m][topic]++;
        nwsum[topic]++;
        ndsum[m]++;

        return topic;
    }

    /**
     * Add to the statistics the values of theta and phi for the current state.
     */
    private void updateParams()
    {
        for (int m = 0; m < documents.length; m++)
        {
            for (int k = 0; k < K; k++)
            {
                thetasum[m][k] += (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
            }
        }
        for (int k = 0; k < K; k++)
        {
            for (int w = 0; w < V; w++)
            {
                phisum[k][w] += (nw[w][k] + beta) / (nwsum[k] + V * beta);
            }
        }
        numstats++;
    }

    /**
     * Retrieve estimated document--topic associations. If sample lag > 0 and
     * samples have been collected, the mean of all sampled theta[][] values is
     * returned; otherwise the estimate from the current counts is used
     * (previously this divided by zero and returned NaN when no samples
     * had been taken).
     *
     * @return theta multinomial mixture of document topics (M x K)
     */
    public double[][] getTheta()
    {
        double[][] theta = new double[documents.length][K];
        if (SAMPLE_LAG > 0 && numstats > 0)
        {
            for (int m = 0; m < documents.length; m++)
            {
                for (int k = 0; k < K; k++)
                {
                    theta[m][k] = thetasum[m][k] / numstats;
                }
            }
        }
        else
        {
            for (int m = 0; m < documents.length; m++)
            {
                for (int k = 0; k < K; k++)
                {
                    theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
                }
            }
        }
        return theta;
    }

    /**
     * Retrieve estimated topic--word associations. If sample lag > 0 and
     * samples have been collected, the mean of all sampled phi[][] values is
     * returned; otherwise the estimate from the current counts is used.
     *
     * @return phi multinomial mixture of topic words (K x V)
     */
    public double[][] getPhi()
    {
        double[][] phi = new double[K][V];
        if (SAMPLE_LAG > 0 && numstats > 0)
        {
            for (int k = 0; k < K; k++)
            {
                for (int w = 0; w < V; w++)
                {
                    phi[k][w] = phisum[k][w] / numstats;
                }
            }
        }
        else
        {
            for (int k = 0; k < K; k++)
            {
                for (int w = 0; w < V; w++)
                {
                    phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);
                }
            }
        }
        return phi;
    }

    /**
     * Print an ASCII histogram of a multinomial/evidence vector to stdout.
     * (Note: despite earlier documentation, nothing is returned.)
     *
     * @param data vector of evidence
     * @param fmax max frequency in display
     */
    public static void hist(double[] data, int fmax)
    {
        double[] hist = new double[data.length];
        // scale maximum
        double hmax = 0;
        for (int i = 0; i < data.length; i++)
        {
            hmax = Math.max(data[i], hmax);
        }
        double shrink = fmax / hmax;
        for (int i = 0; i < data.length; i++)
        {
            hist[i] = shrink * data[i];
        }

        NumberFormat nf = new DecimalFormat("00");
        String scale = "";
        for (int i = 1; i < fmax / 10 + 1; i++)
        {
            scale += "    .    " + i % 10;
        }

        System.out.println("x" + nf.format(hmax / fmax) + "\t0" + scale);
        for (int i = 0; i < hist.length; i++)
        {
            System.out.print(i + "\t|");
            for (int j = 0; j < Math.round(hist[i]); j++)
            {
                if ((j + 1) % 10 == 0)
                    System.out.print("]");
                else
                    System.out.print("|");
            }
            System.out.println();
        }
    }

    /**
     * Configure the Gibbs sampler. NOTE: these settings are static and thus
     * shared by all sampler instances (and by {@link #inference}).
     *
     * @param iterations number of total iterations
     * @param burnIn number of burn-in iterations
     * @param thinInterval update statistics interval
     * @param sampleLag sample interval (-1 for just one sample at the end)
     */
    public void configure(int iterations, int burnIn, int thinInterval,
                          int sampleLag)
    {
        ITERATIONS = iterations;
        BURN_IN = burnIn;
        THIN_INTERVAL = thinInterval;
        SAMPLE_LAG = sampleLag;
    }

    /**
     * Infer the topic distribution of a new document given a pre-trained
     * topic--term matrix phi ("fold-in"). Topics are Gibbs-sampled for the new
     * document using phi as the (fixed) topic--term probabilities.
     *
     * <p>BUG FIX: the previous implementation never used {@code phi} in the
     * sampling formula — it rebuilt word--topic counts from the single new
     * document alone, so the pre-trained model was ignored. The word--topic
     * factor is now {@code phi[k][w]} as the method's contract requires.</p>
     *
     * @param alpha symmetric document--topic prior
     * @param beta symmetric topic--term prior (kept for interface compatibility;
     *             the trained phi already encodes the beta smoothing)
     * @param phi pre-trained topic--term matrix (K x V)
     * @param doc the new document, as an array of term ids
     * @return theta, the inferred topic distribution of the document (length K)
     */
    public static double[] inference(double alpha, double beta, double[][] phi, int[] doc)
    {
        int K = phi.length;
        int N = doc.length;

        // document--topic counts for the new document only
        int[] nd = new int[K];
        int ndsum = 0;

        // random initial topic assignment for each word
        int[] z = new int[N];
        for (int n = 0; n < N; n++)
        {
            int topic = (int) (Math.random() * K);
            z[n] = topic;
            nd[topic]++;
        }
        ndsum = N;

        for (int i = 0; i < ITERATIONS; i++)
        {
            for (int n = 0; n < z.length; n++)
            {
                // remove z_i from the count variables
                int topic = z[n];
                nd[topic]--;
                ndsum--;

                // multinomial sampling via the cumulative method, holding the
                // trained topic--term distribution phi fixed
                double[] p = new double[K];
                for (int k = 0; k < K; k++)
                {
                    p[k] = phi[k][doc[n]]
                            * (nd[k] + alpha) / (ndsum + K * alpha);
                }
                for (int k = 1; k < p.length; k++)
                {
                    p[k] += p[k - 1];
                }
                double u = Math.random() * p[K - 1];
                for (topic = 0; topic < p.length; topic++)
                {
                    if (u < p[topic])
                        break;
                }
                // guard against floating-point round-off
                if (topic == p.length)
                    topic = p.length - 1;

                // add the newly estimated z_i back to the count variables
                nd[topic]++;
                ndsum++;
                z[n] = topic;
            }
        }

        double[] theta = new double[K];
        for (int k = 0; k < K; k++)
        {
            theta[k] = (nd[k] + alpha) / (ndsum + K * alpha);
        }
        return theta;
    }

    /**
     * Infer a new document's topic distribution with the default priors
     * alpha=2.0, beta=0.5.
     *
     * @param phi pre-trained topic--term matrix (K x V)
     * @param doc the new document, as an array of term ids
     * @return theta, the inferred topic distribution (length K)
     */
    public static double[] inference(double[][] phi, int[] doc)
    {
        return inference(2.0, 0.5, phi, doc);
    }

    /**
     * Demo entry point: trains on a tiny synthetic corpus, prints the
     * resulting theta and phi as Hinton-style shade diagrams, then infers a
     * new document from the trained phi.
     */
    public static void main(String[] args)
    {
        // words in documents (term-id lists)
        int[][] documents = {
                {1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 6},
                {2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2},
                {1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 0},
                {5, 6, 6, 2, 3, 3, 6, 5, 6, 2, 2, 6, 5, 6, 6, 6, 0},
                {2, 2, 4, 4, 4, 4, 1, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 0},
                {5, 4, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2}};
        // vocabulary size
        int V = 7;
        int M = documents.length;
        // number of topics
        int K = 2;
        // good values alpha = 2, beta = .5
        double alpha = 2;
        double beta = .5;

        System.out.println("Latent Dirichlet Allocation using Gibbs Sampling.");

        LdaGibbsSampler lda = new LdaGibbsSampler(documents, V);
        lda.configure(10000, 2000, 100, 10);
        lda.gibbs(K, alpha, beta);

        double[][] theta = lda.getTheta();
        double[][] phi = lda.getPhi();

        System.out.println();
        System.out.println();
        System.out.println("Document--Topic Associations, Theta[d][k] (alpha="
                + alpha + ")");
        System.out.print("d\\k\t");
        for (int k = 0; k < theta[0].length; k++)
        {
            System.out.print("   " + k % 10 + "    ");
        }
        System.out.println();
        for (int m = 0; m < theta.length; m++)
        {
            System.out.print(m + "\t");
            for (int k = 0; k < theta[m].length; k++)
            {
                System.out.print(shadeDouble(theta[m][k], 1) + " ");
            }
            System.out.println();
        }
        System.out.println();
        System.out.println("Topic--Term Associations, Phi[k][w] (beta=" + beta
                + ")");
        System.out.print("k\\w\t");
        for (int w = 0; w < phi[0].length; w++)
        {
            System.out.print("   " + w % 10 + "    ");
        }
        System.out.println();
        for (int k = 0; k < phi.length; k++)
        {
            System.out.print(k + "\t");
            for (int w = 0; w < phi[k].length; w++)
            {
                System.out.print(shadeDouble(phi[k][w], 1) + " ");
            }
            System.out.println();
        }

        // infer a new document from the trained phi
        int[] aNewDocument = {2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2};
        double[] newTheta = inference(alpha, beta, phi, aNewDocument);
        for (int k = 0; k < newTheta.length; k++)
        {
            System.out.print(shadeDouble(newTheta[k], 1) + " ");
        }
        System.out.println();
    }

    /** Shade strings for {@link #shadeDouble}, indexed by magnitude 0..10. */
    static String[] shades = {"     ", ".    ", ":    ", ":.   ", "::   ",
            "::.  ", ":::  ", ":::. ", ":::: ", "::::.", ":::::"};

    static NumberFormat lnf = new DecimalFormat("00E0");

    /**
     * Create a string representation whose gray value appears as an indicator
     * of magnitude, cf. Hinton diagrams in statistics.
     *
     * @param d value
     * @param max maximum value
     * @return a 7-character shade cell, or an exponent form "&lt;..&gt;" when
     *         d is out of the [0, max] display range
     */
    public static String shadeDouble(double d, double max)
    {
        int a = (int) Math.floor(d * 10 / max + 0.5);
        if (a > 10 || a < 0)
        {
            String x = lnf.format(d);
            a = 5 - x.length();
            for (int i = 0; i < a; i++)
            {
                x += " ";
            }
            return "<" + x + ">";
        }
        return "[" + shades[a] + "]";
    }
}