/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.util.*;
import java.util.logging.*;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;
import cc.mallet.types.*;
import cc.mallet.util.*;
import gnu.trove.*;
/**
 * A non-parametric topic model in the style of the hierarchical Dirichlet
 * process, using the "minimal path" assumption (each document contributes
 * at most one "table" per topic) to reduce bookkeeping.
 *
 * @author David Mimno
 */
public class NPTopicModel implements Serializable {

	private static final long serialVersionUID = 1L;
private static Logger logger = MalletLogger.getLogger(NPTopicModel.class.getName());
// the training instances and their topic assignments
protected ArrayList<TopicAssignment> data;
// the alphabet for the input data
protected Alphabet alphabet;
// the alphabet for the topics
protected LabelAlphabet topicAlphabet;
// The largest topic ID seen so far
protected int maxTopic;
// The current number of topics
protected int numTopics;
// The size of the vocabulary
protected int numTypes;
// Prior parameters
protected double alpha;
protected double gamma;
protected double beta; // Prior on per-topic multinomial distribution over words
protected double betaSum;
public static final double DEFAULT_BETA = 0.01;
// Statistics needed for sampling.
protected TIntIntHashMap[] typeTopicCounts; // indexed by <feature index, topic index>
protected TIntIntHashMap tokensPerTopic; // indexed by <topic index>
// The number of documents that contain at least one
// token with a given topic.
protected TIntIntHashMap docsPerTopic;
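	// The total number of (document, topic) pairs with at least one token;
	// under the minimal path assumption this is the total number of "tables".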
protected int totalDocTopics = 0;
public int showTopicsInterval = 50;
public int wordsPerTopic = 10;
protected Randoms random;
protected NumberFormat formatter;
protected boolean printLogLikelihood = false;
/** @param alpha this parameter balances the local document topic counts with
* the global distribution over topics.
* @param gamma this parameter is the weight on a completely new, never-before-seen topic
* in the global distribution.
 * @param beta this parameter is the symmetric Dirichlet smoothing on the topic-word
 *  distributions: smaller values produce more sharply peaked topics.
*/
public NPTopicModel (double alpha, double gamma, double beta) {
this.data = new ArrayList<TopicAssignment>();
this.topicAlphabet = AlphabetFactory.labelAlphabetOfSize(1);
this.alpha = alpha;
this.gamma = gamma;
this.beta = beta;
this.random = new Randoms();
tokensPerTopic = new TIntIntHashMap();
docsPerTopic = new TIntIntHashMap();
formatter = NumberFormat.getInstance();
formatter.setMaximumFractionDigits(5);
logger.info("Non-Parametric LDA");
}
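	/* A minimal usage sketch (the file name and parameter values here are
	 * illustrative, not prescriptive):
	 *
	 *   InstanceList training = InstanceList.load(new File("topic-input.mallet"));
	 *   NPTopicModel model = new NPTopicModel(5.0, 10.0, NPTopicModel.DEFAULT_BETA);
	 *   model.setRandomSeed(42);
	 *   model.addInstances(training, 10);
	 *   model.sample(500);
	 *   System.out.println(model.topWords(10));
	 */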
public void setTopicDisplay(int interval, int n) {
this.showTopicsInterval = interval;
this.wordsPerTopic = n;
}
public void setRandomSeed(int seed) {
random = new Randoms(seed);
}
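	/** Load the training instances and initialize topic assignments by
	 *  giving each token a uniformly random topic from among
	 *  <code>initialTopics</code> initial topics. */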
public void addInstances (InstanceList training, int initialTopics) {
alphabet = training.getDataAlphabet();
numTypes = alphabet.size();
betaSum = beta * numTypes;
typeTopicCounts = new TIntIntHashMap[numTypes];
for (int type=0; type < numTypes; type++) {
typeTopicCounts[type] = new TIntIntHashMap();
}
numTopics = initialTopics;
		for (Instance instance : training) {
TIntIntHashMap topicCounts = new TIntIntHashMap();
FeatureSequence tokens = (FeatureSequence) instance.getData();
LabelSequence topicSequence =
new LabelSequence(topicAlphabet, new int[ tokens.size() ]);
int[] topics = topicSequence.getFeatures();
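			// Assign each token uniformly at random to one of the initial topics.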
for (int position = 0; position < tokens.size(); position++) {
int topic = random.nextInt(numTopics);
tokensPerTopic.adjustOrPutValue(topic, 1, 1);
topics[position] = topic;
// Keep track of the number of docs with at least one token
// in a given topic.
if (! topicCounts.containsKey(topic)) {
docsPerTopic.adjustOrPutValue(topic, 1, 1);
totalDocTopics++;
topicCounts.put(topic, 1);
}
else {
topicCounts.adjustValue(topic, 1);
}
int type = tokens.getIndexAtPosition(position);
typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
}
TopicAssignment t = new TopicAssignment (instance, topicSequence);
data.add (t);
}
maxTopic = numTopics - 1;
}
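	/** Run Gibbs sampling over all documents for the given number of iterations. */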
public void sample (int iterations) throws IOException {
for (int iteration = 1; iteration <= iterations; iteration++) {
long iterationStart = System.currentTimeMillis();
// Loop over every document in the corpus
for (int doc = 0; doc < data.size(); doc++) {
FeatureSequence tokenSequence =
(FeatureSequence) data.get(doc).instance.getData();
LabelSequence topicSequence =
(LabelSequence) data.get(doc).topicSequence;
sampleTopicsForOneDoc (tokenSequence, topicSequence);
}
long elapsedMillis = System.currentTimeMillis() - iterationStart;
logger.info(iteration + "\t" + elapsedMillis + "ms\t" + numTopics);
// Occasionally print more information
if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
logger.info("<" + iteration + "> #Topics: " + numTopics + "\n" +
topWords (wordsPerTopic));
}
}
}
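	/** Resample the topic of every token in one document: decrement the
	 *  token's counts, score every existing topic plus one potential new
	 *  topic, then sample from the resulting unnormalized distribution. */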
protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
FeatureSequence topicSequence) {
int[] topics = topicSequence.getFeatures();
TIntIntHashMap currentTypeTopicCounts;
int type, oldTopic, newTopic;
double topicWeightsSum;
int docLength = tokenSequence.getLength();
TIntIntHashMap localTopicCounts = new TIntIntHashMap();
// populate topic counts
for (int position = 0; position < docLength; position++) {
localTopicCounts.adjustOrPutValue(topics[position], 1, 1);
}
double score, sum;
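		// One score slot for each existing topic, plus one extra slot for
		// the possibility of creating a brand-new topic.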
double[] topicTermScores = new double[numTopics + 1];
// Store a list of all the topics that currently exist.
int[] allTopics = docsPerTopic.keys();
// Iterate over the positions (words) in the document
for (int position = 0; position < docLength; position++) {
type = tokenSequence.getIndexAtPosition(position);
oldTopic = topics[position];
// Grab the relevant row from our two-dimensional array
currentTypeTopicCounts = typeTopicCounts[type];
// Remove this token from all counts.
int currentCount = localTopicCounts.get(oldTopic);
// Was this the only token of this topic in the doc?
if (currentCount == 1) {
localTopicCounts.remove(oldTopic);
// Was this the only doc with this topic?
int docCount = docsPerTopic.get(oldTopic);
if (docCount == 1) {
// This should be the very last token
assert(tokensPerTopic.get(oldTopic) == 1);
// Get rid of the topic
docsPerTopic.remove(oldTopic);
totalDocTopics--;
tokensPerTopic.remove(oldTopic);
numTopics--;
allTopics = docsPerTopic.keys();
topicTermScores = new double[numTopics + 1];
}
else {
// This is the last in the doc, but the topic still exists
docsPerTopic.adjustValue(oldTopic, -1);
totalDocTopics--;
tokensPerTopic.adjustValue(oldTopic, -1);
}
}
else {
// There is at least one other token in this doc
// with this topic.
localTopicCounts.adjustValue(oldTopic, -1);
tokensPerTopic.adjustValue(oldTopic, -1);
}
if (currentTypeTopicCounts.get(oldTopic) == 1) {
currentTypeTopicCounts.remove(oldTopic);
}
else {
currentTypeTopicCounts.adjustValue(oldTopic, -1);
}
// Now calculate and add up the scores for each topic for this word
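			// Under the minimal path assumption each (doc, topic) pair counts
			// as one "table", so an existing topic k is scored as
			//
			//   ( n_{d,k} + alpha * m_k / (M + gamma) )
			//     * (n_{w,k} + beta) / (n_k + betaSum)
			//
			// where n_{d,k} is the topic's count in this document, m_k is
			// docsPerTopic.get(k), M is totalDocTopics, n_{w,k} is the
			// type-topic count, and n_k is tokensPerTopic.get(k).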
sum = 0.0;
// First do the topics that currently exist
for (int i = 0; i < numTopics; i++) {
int topic = allTopics[i];
topicTermScores[i] =
(localTopicCounts.get(topic) +
alpha * (docsPerTopic.get(topic) /
(totalDocTopics + gamma))) *
(currentTypeTopicCounts.get(topic) + beta) /
(tokensPerTopic.get(topic) + betaSum);
sum += topicTermScores[i];
}
// Add the weight for a new topic
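			// Its global weight is gamma / (totalDocTopics + gamma), and the
			// probability of the current word under a brand-new topic is
			// beta / betaSum, which simplifies to 1 / numTypes.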
topicTermScores[numTopics] =
alpha * gamma / ( numTypes * (totalDocTopics + gamma) );
sum += topicTermScores[numTopics];
			// Choose a random point between 0 and the sum of all topic scores
			double sample = random.nextUniform() * sum;

			// Figure out which topic contains that point by subtracting
			// scores until the remainder is exhausted. The bound on i guards
			// against floating-point round-off pushing us past the last slot.
			newTopic = -1;
			int i = -1;
			while (sample > 0.0 && i < numTopics) {
				i++;
				sample -= topicTermScores[i];
			}
if (i < numTopics) {
newTopic = allTopics[i];
topics[position] = newTopic;
currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
tokensPerTopic.adjustValue(newTopic, 1);
if (localTopicCounts.containsKey(newTopic)) {
localTopicCounts.adjustValue(newTopic, 1);
}
else {
// This is not a new topic, but it is new for this doc.
localTopicCounts.put(newTopic, 1);
docsPerTopic.adjustValue(newTopic, 1);
totalDocTopics++;
}
}
else {
// completely new topic: first generate an ID
newTopic = maxTopic + 1;
maxTopic = newTopic;
numTopics++;
topics[position] = newTopic;
localTopicCounts.put(newTopic, 1);
docsPerTopic.put(newTopic, 1);
totalDocTopics++;
currentTypeTopicCounts.put(newTopic, 1);
tokensPerTopic.put(newTopic, 1);
allTopics = docsPerTopic.keys();
topicTermScores = new double[numTopics + 1];
}
}
}
//
// Methods for displaying and saving results
//
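	/** Return a text summary of each current topic: its ID, its total token
	 *  count, and up to <code>numWords</code> of its most frequent words. */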
public String topWords (int numWords) {
StringBuilder output = new StringBuilder();
IDSorter[] sortedWords = new IDSorter[numTypes];
for (int topic: docsPerTopic.keys()) {
for (int type = 0; type < numTypes; type++) {
sortedWords[type] = new IDSorter(type, typeTopicCounts[type].get(topic));
}
Arrays.sort(sortedWords);
output.append(topic + "\t" + tokensPerTopic.get(topic) + "\t");
			for (int i=0; i < numWords && i < numTypes; i++) {
if (sortedWords[i].getWeight() < 1.0) {
break;
}
output.append(alphabet.lookupObject(sortedWords[i].getID()) + " ");
}
output.append("\n");
}
return output.toString();
}
public void printState (File f) throws IOException {
PrintStream out =
new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
printState(out);
out.close();
}
public void printState (PrintStream out) {
out.println ("#doc source pos typeindex type topic");
for (int doc = 0; doc < data.size(); doc++) {
FeatureSequence tokenSequence = (FeatureSequence) data.get(doc).instance.getData();
LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
String source = "NA";
if (data.get(doc).instance.getSource() != null) {
source = data.get(doc).instance.getSource().toString();
}
for (int position = 0; position < topicSequence.getLength(); position++) {
int type = tokenSequence.getIndexAtPosition(position);
int topic = topicSequence.getIndexAtPosition(position);
out.print(doc); out.print(' ');
out.print(source); out.print(' ');
out.print(position); out.print(' ');
out.print(type); out.print(' ');
out.print(alphabet.lookupObject(type)); out.print(' ');
out.print(topic); out.println();
}
}
}
	public static void main (String[] args) throws IOException {
		// Usage: NPTopicModel <serialized InstanceList file> [initial number of topics]
		InstanceList training = InstanceList.load (new File(args[0]));

		int initialTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;

		NPTopicModel lda = new NPTopicModel (5.0, 10.0, 0.1);
		lda.addInstances(training, initialTopics);
lda.sample(1000);
}
}