/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.lda;
import java.util.Iterator;
import org.apache.commons.math.special.Gamma;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
/**
* Class for performing inference on a document, which involves computing (an approximation to)
* p(word|topic) for each word and topic, and a prior distribution p(topic) for each topic.
*/
public class LDAInference {

  /** Relative change in log-likelihood below which the E-step loop in {@link #infer} stops. */
  private static final double E_STEP_CONVERGENCE = 1.0E-6;

  /** Maximum number of E-step iterations {@link #infer} performs for one document. */
  private static final int MAX_ITER = 20;

  /**
   * Scratch matrix with one row per topic and one column per entry of the document vector.
   * It is allocated by {@link #createPhiMatrix(int)} and reused (zeroed) between calls to
   * {@link #infer(Vector)} whenever its dimensions still fit.
   */
  private DenseMatrix phi;

  /** Model state supplying the topic count, topic smoothing and log p(word|topic). */
  private final LDAState state;

  public LDAInference(LDAState state) {
    this.state = state;
  }

  /**
   * The result of running {@link LDAInference#infer(Vector)} on one document.
   *
   * <p>{@code gamma} holds one weight per topic (topic smoothing plus the expected number of word
   * occurrences assigned to that topic); {@code phi(k, w)} is the log of q(topic k | word w) for
   * this document.
   *
   * <p>The phi matrix is not copied: it is the scratch matrix of the LDAInference instance that
   * produced this object, so a later {@code infer} call on that instance may overwrite it.
   */
  public static class InferredDocument {

    private final Vector wordCounts;
    private final Vector gamma; // per-topic weights
    private final Matrix mphi; // log q(t | word), indexed as (t, columnMap[word])
    private final int[] columnMap; // maps a word index to its column in mphi
    private final double logLikelihood;

    InferredDocument(Vector wordCounts, Vector gamma, int[] columnMap, Matrix phi, double ll) {
      this.wordCounts = wordCounts;
      this.gamma = gamma;
      this.mphi = phi;
      this.columnMap = columnMap;
      this.logLikelihood = ll;
    }

    /**
     * Looks up the log probability stored for topic {@code k} and word {@code w}.
     *
     * <p>Only words with a non-zero count in the document get their own column; for any other
     * word the column map still holds its default of 0, so the value returned is that of
     * column 0 and is not meaningful.
     *
     * @param k topic index
     * @param w word index into the document vector
     * @return the entry of the phi matrix at row {@code k}, column {@code columnMap[w]}
     */
    public double phi(int k, int w) {
      return mphi.getQuick(k, columnMap[w]);
    }

    public Vector getWordCounts() {
      return wordCounts;
    }

    public Vector getGamma() {
      return gamma;
    }

    public double getLogLikelihood() {
      return logLikelihood;
    }
  }

  /**
   * Performs inference on the given document, returning an InferredDocument.
   *
   * <p>Alternates between recomputing the per-word topic distributions (phi) and the per-topic
   * weights (gamma) until the relative change in log-likelihood drops below
   * {@link #E_STEP_CONVERGENCE} or {@link #MAX_ITER} iterations have run.
   *
   * <p>The returned document shares this instance's internal phi matrix; consume it before
   * calling {@code infer} again on the same instance.
   *
   * @param wordCounts word-count vector of the document; its size is used as the vocabulary size
   * @return the inferred gamma, phi and log-likelihood of the last iteration
   */
  public InferredDocument infer(Vector wordCounts) {
    double docTotal = wordCounts.zSum();
    int docLength = wordCounts.size(); // cardinality of document vectors

    // initialize variational approximation to p(z|doc)
    Vector gamma = new DenseVector(state.getNumTopics());
    gamma.assign(state.getTopicSmoothing() + docTotal / state.getNumTopics());
    Vector nextGamma = new DenseVector(state.getNumTopics());
    createPhiMatrix(docLength);

    Vector digammaGamma = digammaGamma(gamma);

    // map[word] is the column of phi that holds the topic distribution of that word
    int[] map = new int[docLength];

    int iteration = 0;
    boolean converged = false;
    // Positive start value: together with the "oldLL < 0.0" guard below it ensures the first
    // iteration is never declared converged.
    double oldLL = 1.0;

    while (!converged && iteration < MAX_ITER) {
      nextGamma.assign(state.getTopicSmoothing()); // nG := alpha, for all topics

      int mapping = 0;
      for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
        Vector.Element e = iter.next();
        int word = e.index();
        Vector phiW = eStepForWord(word, digammaGamma);
        phi.assignColumn(mapping, phiW);
        if (iteration == 0) { // the word -> column assignment is identical in every iteration
          map[word] = mapping;
        }

        // accumulate count(word) * q(k | word) into the next gamma
        for (int k = 0; k < nextGamma.size(); ++k) {
          double g = nextGamma.getQuick(k);
          nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.getQuick(k)));
        }

        mapping++;
      }

      // swap the two gamma buffers so the old one can be overwritten in the next iteration
      Vector tempG = gamma;
      gamma = nextGamma;
      nextGamma = tempG;

      digammaGamma = digammaGamma(gamma);

      double ll = computeLikelihood(wordCounts, map, phi, gamma, digammaGamma);
      converged = oldLL < 0.0 && (oldLL - ll) / oldLL < E_STEP_CONVERGENCE;

      oldLL = ll;
      iteration++;
    }

    return new InferredDocument(wordCounts, gamma, map, phi, oldLL);
  }

  /**
   * Computes the digamma of every entry of gamma, shifted by the digamma of the sum of gamma.
   *
   * @param gamma the per-topic weights
   * @return a new vector whose entries are digamma(gamma[i]) - digamma(gamma.zSum())
   */
  private Vector digammaGamma(Vector gamma) {
    // digamma is expensive, precompute
    Vector digammaGamma = digamma(gamma);
    // and log normalize:
    double digammaSumGamma = digamma(gamma.zSum());
    for (int i = 0; i < state.getNumTopics(); i++) {
      digammaGamma.setQuick(i, digammaGamma.getQuick(i) - digammaSumGamma);
    }
    return digammaGamma;
  }

  /**
   * Makes sure {@link #phi} is a zero-filled (numTopics x docLength) matrix.
   *
   * <p>The existing matrix is cleared and reused when both of its dimensions already match;
   * otherwise a new one is allocated. The document length is the column dimension, so it must
   * be compared against {@code columnSize()}, not {@code rowSize()}.
   *
   * @param docLength cardinality of the document vector, i.e. the number of columns required
   */
  private void createPhiMatrix(int docLength) {
    int numTopics = state.getNumTopics();
    if (phi == null || phi.rowSize() != numTopics || phi.columnSize() != docLength) {
      phi = new DenseMatrix(numTopics, docLength);
    } else {
      phi.assign(0);
    }
  }

  /**
   * Computes the log likelihood of the wordCounts vector, given phi, gamma and digamma(gamma).
   *
   * <p>With alpha the topic smoothing, K the number of topics and dg[k] the entries of
   * {@code digammaGamma}, the value returned is
   * <pre>
   *   logGamma(alpha * K) - K * logGamma(alpha)
   *   + sum over k of ((alpha - gamma[k]) * dg[k] + logGamma(gamma[k]))
   *   - logGamma(sum of gamma)
   *   + sum over non-zero words w and topics k of
   *       count(w) * exp(phi[k][w]) * (dg[k] - phi[k][w] + log p(w|k))
   * </pre>
   * where logGamma(x) = log(gamma(x)).
   *
   * @param wordCounts word-count vector of the document
   * @param map maps a word index to its column in {@code phi}
   * @param phi log q(topic | word), one row per topic
   * @param gamma per-topic weights
   * @param digammaGamma digamma(gamma[k]) - digamma(sum of gamma), as produced by
   *        {@link #digammaGamma(Vector)}
   * @return the log likelihood estimate described above
   */
  private double computeLikelihood(Vector wordCounts, int[] map, Matrix phi, Vector gamma, Vector digammaGamma) {
    double ll = 0.0;

    // log normalizer for q(gamma);
    ll += Gamma.logGamma(state.getTopicSmoothing() * state.getNumTopics());
    ll -= state.getNumTopics() * Gamma.logGamma(state.getTopicSmoothing());

    // now for the rest of q(gamma);
    for (int k = 0; k < state.getNumTopics(); ++k) {
      double gammaK = gamma.get(k);
      ll += (state.getTopicSmoothing() - gammaK) * digammaGamma.getQuick(k);
      ll += Gamma.logGamma(gammaK);
    }
    ll -= Gamma.logGamma(gamma.zSum());

    // for each word
    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
      Vector.Element e = iter.next();
      int w = e.index();
      double n = e.get();
      int mapping = map[w];
      // now for each topic:
      for (int k = 0; k < state.getNumTopics(); k++) {
        double llPart = 0.0;
        double phiKMapping = phi.getQuick(k, mapping);
        llPart += Math.exp(phiKMapping)
            * (digammaGamma.getQuick(k) - phiKMapping + state.logProbWordGivenTopic(w, k));
        ll += llPart * n;
      }
    }
    return ll;
  }

  /**
   * Compute log q(k|w,doc) for each topic k, for a given word.
   *
   * @param word index of the word
   * @param digammaGamma digamma(gamma[k]) - digamma(sum of gamma) for every topic k
   * @return a new vector of log probabilities over topics, log-normalized with LDAUtil.logSum
   */
  private Vector eStepForWord(int word, Vector digammaGamma) {
    Vector phiW = new DenseVector(state.getNumTopics()); // log q(k|w), for each k
    double phiTotal = Double.NEGATIVE_INFINITY; // log Normalizer
    for (int k = 0; k < state.getNumTopics(); ++k) { // update q(k|w)'s param phi
      phiW.setQuick(k, state.logProbWordGivenTopic(word, k) + digammaGamma.getQuick(k));
      phiTotal = LDAUtil.logSum(phiTotal, phiW.getQuick(k));
    }
    for (int i = 0; i < state.getNumTopics(); i++) {
      phiW.setQuick(i, phiW.getQuick(i) - phiTotal); // log normalize
    }
    return phiW;
  }

  /**
   * Applies {@link #digamma(double)} to every entry of a vector.
   *
   * @param v input vector; it is not modified
   * @return a new dense vector of the same size holding digamma(v[i])
   */
  private static Vector digamma(Vector v) {
    Vector digammaGamma = new DenseVector(v.size());
    // The first argument is the (unused) entry of the freshly created result vector;
    // the second is the corresponding entry of v.
    digammaGamma.assign(v, new DoubleDoubleFunction() {
      @Override
      public double apply(double unused, double g) {
        return digamma(g);
      }
    });
    return digammaGamma;
  }

  /**
   * Approximation to the digamma function, from Radford Neal.
   *
   * Original License: Copyright (c) 1995-2003 by Radford M. Neal
   *
   * Permission is granted for anyone to copy, use, modify, or distribute this program and accompanying
   * programs and documents for any purpose, provided this copyright notice is retained and prominently
   * displayed, along with a note saying that the original programs are available from Radford Neal's web
   * page, and note is made of any changes made to the programs. The programs and documents are distributed
   * without any warranty, express or implied. As the programs were written for research purposes only, they
   * have not been tested to the degree that would be advisable in any important application. All use of these
   * programs is entirely at the user's own risk.
   *
   *
   * Ported to Java for Mahout.
   *
   * @param x the argument; callers in this class pass gamma values (presumably positive)
   * @return the approximate value of digamma(x)
   */
  private static double digamma(double x) {
    double r = 0.0;

    // Shift the argument above 5 using digamma(x) = digamma(x + 1) - 1 / x,
    // so that the series below is evaluated only for x > 5.
    while (x <= 5) {
      r -= 1 / x;
      x += 1;
    }

    double f = 1.0 / (x * x);
    double t = f * (-1.0 / 12.0 + f * (1.0 / 120.0 + f * (-1.0 / 252.0 + f * (1.0 / 240.0
        + f * (-1.0 / 132.0 + f * (691.0 / 32760.0 + f * (-1.0 / 12.0 + f * 3617.0 / 8160.0)))))));
    return r + Math.log(x) - 0.5 / x + t;
  }
}