/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.util.*;
import java.util.logging.*;
import java.util.zip.*;

import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.*;

/**
 * A simple implementation of Latent Dirichlet Allocation using Gibbs sampling.
 * This code is slower than the regular Mallet LDA implementation, but provides a
 * better starting place for understanding how sampling works and for
 * building new topic models.
 *
 * @author David Mimno, Andrew McCallum
 */

public class SimpleLDA implements Serializable {

	private static Logger logger = MalletLogger.getLogger(SimpleLDA.class.getName());

	// The training instances and their topic assignments
	protected ArrayList<TopicAssignment> data;

	// The alphabet for the input data
	protected Alphabet alphabet;

	// The alphabet for the topics
	protected LabelAlphabet topicAlphabet;

	// The number of topics requested
	protected int numTopics;

	// The size of the vocabulary
	protected int numTypes;

	// Prior parameters
	protected double alpha;    // Dirichlet(alpha, alpha, ...) is the distribution over topics
	protected double alphaSum;
	protected double beta;     // Prior on the per-topic multinomial distribution over words
	protected double betaSum;

	public static final double DEFAULT_BETA = 0.01;
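	// Editorial note (not in the original source): alpha is derived as
	// alphaSum / numTopics in the constructor below, so the main() at the
	// bottom of this file, with alphaSum = 50.0 and the default 200 topics,
	// yields a symmetric prior of alpha = 0.25 per topic. beta is applied
	// uniformly to every word type, with betaSum = beta * numTypes set in
	// addInstances().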
	// An array to put the topic counts for the current document.
	// Initialized locally below.  Defined here to avoid
	// garbage collection overhead.
	protected int[] oneDocTopicCounts; // indexed by <document index, topic index>

	// Statistics needed for sampling.
	protected int[][] typeTopicCounts; // indexed by <feature index, topic index>
	protected int[] tokensPerTopic;    // indexed by <topic index>

	public int showTopicsInterval = 50;
	public int wordsPerTopic = 10;

	protected Randoms random;
	protected NumberFormat formatter;
	protected boolean printLogLikelihood = false;

	public SimpleLDA (int numberOfTopics) {
		this (numberOfTopics, numberOfTopics, DEFAULT_BETA);
	}

	public SimpleLDA (int numberOfTopics, double alphaSum, double beta) {
		this (numberOfTopics, alphaSum, beta, new Randoms());
	}

	private static LabelAlphabet newLabelAlphabet (int numTopics) {
		LabelAlphabet ret = new LabelAlphabet();
		for (int i = 0; i < numTopics; i++)
			ret.lookupIndex("topic" + i);
		return ret;
	}

	public SimpleLDA (int numberOfTopics, double alphaSum, double beta, Randoms random) {
		this (newLabelAlphabet (numberOfTopics), alphaSum, beta, random);
	}

	public SimpleLDA (LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random) {
		this.data = new ArrayList<TopicAssignment>();
		this.topicAlphabet = topicAlphabet;
		this.numTopics = topicAlphabet.size();

		this.alphaSum = alphaSum;
		this.alpha = alphaSum / numTopics;
		this.beta = beta;
		this.random = random;

		oneDocTopicCounts = new int[numTopics];
		tokensPerTopic = new int[numTopics];

		formatter = NumberFormat.getInstance();
		formatter.setMaximumFractionDigits(5);

		logger.info("Simple LDA: " + numTopics + " topics");
	}

	public Alphabet getAlphabet() { return alphabet; }
	public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
	public int getNumTopics() { return numTopics; }
	public ArrayList<TopicAssignment> getData() { return data; }

	public void setTopicDisplay(int interval, int n) {
		this.showTopicsInterval = interval;
		this.wordsPerTopic = n;
	}

	public void setRandomSeed(int seed) {
		random = new Randoms(seed);
	}

	public int[][] getTypeTopicCounts() { return typeTopicCounts; }
	public int[] getTopicTotals() { return tokensPerTopic; }

	public void addInstances (InstanceList training) {

		alphabet = training.getDataAlphabet();
		numTypes = alphabet.size();
		betaSum = beta * numTypes;

		typeTopicCounts = new int[numTypes][numTopics];

		int doc = 0;

		for (Instance instance : training) {
			doc++;

			FeatureSequence tokens = (FeatureSequence) instance.getData();
			LabelSequence topicSequence =
				new LabelSequence(topicAlphabet, new int[ tokens.size() ]);

			int[] topics = topicSequence.getFeatures();
			for (int position = 0; position < tokens.size(); position++) {

				int topic = random.nextInt(numTopics);
				topics[position] = topic;
				tokensPerTopic[topic]++;

				int type = tokens.getIndexAtPosition(position);
				typeTopicCounts[type][topic]++;
			}

			TopicAssignment t = new TopicAssignment (instance, topicSequence);
			data.add (t);
		}
	}

	public void sample (int iterations) throws IOException {

		for (int iteration = 1; iteration <= iterations; iteration++) {

			long iterationStart = System.currentTimeMillis();

			// Loop over every document in the corpus
			for (int doc = 0; doc < data.size(); doc++) {
				FeatureSequence tokenSequence =
					(FeatureSequence) data.get(doc).instance.getData();
				LabelSequence topicSequence =
					(LabelSequence) data.get(doc).topicSequence;

				sampleTopicsForOneDoc (tokenSequence, topicSequence);
			}

			long elapsedMillis = System.currentTimeMillis() - iterationStart;
			logger.fine(iteration + "\t" + elapsedMillis + "ms\t");

			// Occasionally print more information
			if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
				logger.info("<" + iteration + "> Log Likelihood: " + modelLogLikelihood() +
							"\n" + topWords (wordsPerTopic));
			}
		}
	}
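	// Editorial note on the method below (a sketch of the standard collapsed
	// Gibbs sampling update for LDA, written with this class's variables):
	// each token's topic assignment z is resampled from its conditional
	// distribution given all other assignments,
	//
	//   P(z = topic | everything else)  proportional to
	//       (alpha + localTopicCounts[topic]) *
	//       (beta + typeTopicCounts[type][topic]) / (betaSum + tokensPerTopic[topic])
	//
	// where all counts exclude the token currently being resampled, which is
	// why the counts are decremented before the scores are computed.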
	protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
										  FeatureSequence topicSequence) {

		int[] oneDocTopics = topicSequence.getFeatures();

		int[] currentTypeTopicCounts;
		int type, oldTopic, newTopic;

		int docLength = tokenSequence.getLength();

		int[] localTopicCounts = new int[numTopics];

		// Populate topic counts for this document
		for (int position = 0; position < docLength; position++) {
			localTopicCounts[oneDocTopics[position]]++;
		}

		double score, sum;
		double[] topicTermScores = new double[numTopics];

		// Iterate over the positions (words) in the document
		for (int position = 0; position < docLength; position++) {
			type = tokenSequence.getIndexAtPosition(position);
			oldTopic = oneDocTopics[position];

			// Grab the relevant row from our two-dimensional array
			currentTypeTopicCounts = typeTopicCounts[type];

			// Remove this token from all counts.
			localTopicCounts[oldTopic]--;
			tokensPerTopic[oldTopic]--;
			assert(tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
			currentTypeTopicCounts[oldTopic]--;

			// Now calculate and add up the scores for each topic for this word
			sum = 0.0;

			// Here's where the math happens! Note that overall performance is
			// dominated by what you do in this loop.
			for (int topic = 0; topic < numTopics; topic++) {
				score =
					(alpha + localTopicCounts[topic]) *
					((beta + currentTypeTopicCounts[topic]) /
					 (betaSum + tokensPerTopic[topic]));
				sum += score;
				topicTermScores[topic] = score;
			}

			// Choose a random point between 0 and the sum of all topic scores
			double sample = random.nextUniform() * sum;

			// Figure out which topic contains that point
			newTopic = -1;
			while (sample > 0.0) {
				newTopic++;
				sample -= topicTermScores[newTopic];
			}

			// Make sure we actually sampled a topic
			if (newTopic == -1) {
				throw new IllegalStateException ("SimpleLDA: New topic not sampled.");
			}

			// Put the new topic into the counts
			oneDocTopics[position] = newTopic;
			localTopicCounts[newTopic]++;
			tokensPerTopic[newTopic]++;
			currentTypeTopicCounts[newTopic]++;
		}
	}
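	// Editorial note on the method below (a sketch, not in the original
	// source): the quantity computed is the joint probability of the words
	// and the current topic assignments given the priors, with the
	// multinomial parameters integrated out:
	//
	//   log P(w, z | alpha, beta)
	//     = sum over documents d of log DirichletMultinomial(topic counts of d | alpha)
	//     + sum over topics t    of log DirichletMultinomial(word counts of t  | beta)
	//
	// Each Dirichlet-multinomial term expands into the log-gamma expression
	// described in the comments inside the method.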
	public double modelLogLikelihood() {
		double logLikelihood = 0.0;

		// The likelihood of the model is a combination of a
		// Dirichlet-multinomial for the words in each topic
		// and a Dirichlet-multinomial for the topics in each
		// document.

		// The likelihood function of a Dirichlet-multinomial is
		//
		//   Gamma( sum_i alpha_i )     prod_i Gamma( alpha_i + N_i )
		//  ------------------------ * -------------------------------
		//   prod_i Gamma( alpha_i )    Gamma( sum_i (alpha_i + N_i) )
		//
		// So the log likelihood is
		//   logGamma( sum_i alpha_i ) - logGamma( sum_i (alpha_i + N_i) ) +
		//   sum_i [ logGamma( alpha_i + N_i ) - logGamma( alpha_i ) ]

		// Do the documents first

		int[] topicCounts = new int[numTopics];
		double[] topicLogGammas = new double[numTopics];
		int[] docTopics;

		for (int topic = 0; topic < numTopics; topic++) {
			topicLogGammas[ topic ] = Dirichlet.logGamma( alpha );
		}

		for (int doc = 0; doc < data.size(); doc++) {
			LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;

			docTopics = topicSequence.getFeatures();

			for (int token = 0; token < docTopics.length; token++) {
				topicCounts[ docTopics[token] ]++;
			}

			for (int topic = 0; topic < numTopics; topic++) {
				if (topicCounts[topic] > 0) {
					logLikelihood += (Dirichlet.logGamma(alpha + topicCounts[topic]) -
									  topicLogGammas[ topic ]);
				}
			}

			// Subtract the (count + parameter) sum term
			logLikelihood -= Dirichlet.logGamma(alphaSum + docTopics.length);

			Arrays.fill(topicCounts, 0);
		}

		// Add the parameter sum term
		logLikelihood += data.size() * Dirichlet.logGamma(alphaSum);

		// And the topics

		// Count the number of type-topic pairs with non-zero counts
		int nonZeroTypeTopics = 0;

		for (int type = 0; type < numTypes; type++) {
			// Reuse this array as a pointer
			topicCounts = typeTopicCounts[type];

			for (int topic = 0; topic < numTopics; topic++) {
				if (topicCounts[topic] == 0) { continue; }

				nonZeroTypeTopics++;
				logLikelihood += Dirichlet.logGamma(beta + topicCounts[topic]);

				if (Double.isNaN(logLikelihood)) {
					logger.warning("NaN in log likelihood calculation: " + topicCounts[topic]);
					return Double.NaN;
				}
			}
		}

		for (int topic = 0; topic < numTopics; topic++) {
			// The per-topic Dirichlet runs over the vocabulary, so its
			// parameter sum is betaSum = beta * numTypes.
			logLikelihood -= Dirichlet.logGamma( betaSum + tokensPerTopic[ topic ] );

			if (Double.isNaN(logLikelihood)) {
				logger.warning("NaN in log likelihood after topic " + topic + " " +
							   tokensPerTopic[ topic ]);
				return Double.NaN;
			}
		}

		// One logGamma(betaSum) parameter-sum term per topic; the
		// logGamma(beta + 0) - logGamma(beta) terms for zero counts cancel,
		// so only non-zero type-topic pairs contribute.
		logLikelihood +=
			(Dirichlet.logGamma(betaSum) * numTopics) -
			(Dirichlet.logGamma(beta) * nonZeroTypeTopics);

		if (Double.isNaN(logLikelihood)) {
			logger.warning("NaN at the end of the log likelihood calculation");
			return Double.NaN;
		}

		return logLikelihood;
	}

	//
	// Methods for displaying and saving results
	//

	public String topWords (int numWords) {

		StringBuilder output = new StringBuilder();

		IDSorter[] sortedWords = new IDSorter[numTypes];

		for (int topic = 0; topic < numTopics; topic++) {
			for (int type = 0; type < numTypes; type++) {
				sortedWords[type] = new IDSorter(type, typeTopicCounts[type][topic]);
			}

			Arrays.sort(sortedWords);

			output.append(topic + "\t" + tokensPerTopic[topic] + "\t");
			for (int i = 0; i < numWords; i++) {
				output.append(alphabet.lookupObject(sortedWords[i].getID()) + " ");
			}
			output.append("\n");
		}

		return output.toString();
	}
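	// For illustration only (words and counts below are hypothetical), each
	// line produced by topWords() is "topic <tab> token total <tab> top words":
	//
	//   0	4521	music album band song rock ...
	//   1	3876	game team season player league ...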
	/**
	 *  @param file        The filename to print to
	 *  @param threshold   Only print topics with proportion greater than this number
	 *  @param max         Print no more than this many topics
	 */
	public void printDocumentTopics (File file, double threshold, int max) throws IOException {
		PrintWriter out = new PrintWriter(file);

		out.print ("#doc source topic proportion ...\n");
		int docLen;
		int[] topicCounts = new int[ numTopics ];

		IDSorter[] sortedTopics = new IDSorter[ numTopics ];
		for (int topic = 0; topic < numTopics; topic++) {
			// Initialize the sorters with dummy values
			sortedTopics[topic] = new IDSorter(topic, topic);
		}

		if (max < 0 || max > numTopics) {
			max = numTopics;
		}

		for (int doc = 0; doc < data.size(); doc++) {
			LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
			int[] currentDocTopics = topicSequence.getFeatures();

			out.print (doc); out.print (' ');

			if (data.get(doc).instance.getSource() != null) {
				out.print (data.get(doc).instance.getSource());
			}
			else {
				out.print ("null-source");
			}

			out.print (' ');
			docLen = currentDocTopics.length;

			// Count up the tokens
			for (int token = 0; token < docLen; token++) {
				topicCounts[ currentDocTopics[token] ]++;
			}

			// And normalize
			for (int topic = 0; topic < numTopics; topic++) {
				sortedTopics[topic].set(topic, (float) topicCounts[topic] / docLen);
			}

			Arrays.sort(sortedTopics);

			for (int i = 0; i < max; i++) {
				if (sortedTopics[i].getWeight() < threshold) { break; }

				out.print (sortedTopics[i].getID() + " " +
						   sortedTopics[i].getWeight() + " ");
			}
			out.print (" \n");

			Arrays.fill(topicCounts, 0);
		}

		// Flush and release the file handle; without this, buffered
		// output may never reach the file.
		out.close();
	}

	public void printState (File f) throws IOException {
		PrintStream out =
			new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
		printState(out);
		out.close();
	}

	public void printState (PrintStream out) {

		out.println ("#doc source pos typeindex type topic");

		for (int doc = 0; doc < data.size(); doc++) {
			FeatureSequence tokenSequence = (FeatureSequence) data.get(doc).instance.getData();
			LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;

			String source = "NA";
			if (data.get(doc).instance.getSource() != null) {
				source = data.get(doc).instance.getSource().toString();
			}

			for (int position = 0; position < topicSequence.getLength(); position++) {
				int type = tokenSequence.getIndexAtPosition(position);
				int topic = topicSequence.getIndexAtPosition(position);

				out.print(doc); out.print(' ');
				out.print(source); out.print(' ');
				out.print(position); out.print(' ');
				out.print(type); out.print(' ');
				out.print(alphabet.lookupObject(type)); out.print(' ');
				out.print(topic); out.println();
			}
		}
	}
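	// For illustration only (values below are hypothetical), printState()
	// writes one line per token after the header, matching the columns above:
	//
	//   #doc source pos typeindex type topic
	//   0 NA 0 153 market 12
	//   0 NA 1 87 stock 12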
	// Serialization

	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 0;

	public void write (File f) {
		try {
			ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(f));
			oos.writeObject(this);
			oos.close();
		}
		catch (IOException e) {
			System.err.println("Exception writing file " + f + ": " + e);
		}
	}

	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);

		// Instance lists
		out.writeObject (data);
		out.writeObject (alphabet);
		out.writeObject (topicAlphabet);

		out.writeInt (numTopics);

		// Write alpha as a primitive so that readObject, which calls
		// readDouble(), can recover it.
		out.writeDouble (alpha);
		out.writeDouble (beta);
		out.writeDouble (betaSum);

		out.writeInt (showTopicsInterval);
		out.writeInt (wordsPerTopic);

		out.writeObject (random);
		out.writeObject (formatter);
		out.writeBoolean (printLogLikelihood);

		out.writeObject (typeTopicCounts);

		for (int ti = 0; ti < numTopics; ti++) {
			out.writeInt (tokensPerTopic[ti]);
		}
	}

	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		int version = in.readInt (); // currently unused, but must be read to advance the stream

		data = (ArrayList<TopicAssignment>) in.readObject ();
		alphabet = (Alphabet) in.readObject();
		topicAlphabet = (LabelAlphabet) in.readObject();

		numTopics = in.readInt();

		alpha = in.readDouble();
		alphaSum = alpha * numTopics;
		beta = in.readDouble();
		betaSum = in.readDouble();

		showTopicsInterval = in.readInt();
		wordsPerTopic = in.readInt();

		random = (Randoms) in.readObject();
		formatter = (NumberFormat) in.readObject();
		printLogLikelihood = in.readBoolean();

		numTypes = alphabet.size();

		typeTopicCounts = (int[][]) in.readObject();
		tokensPerTopic = new int[numTopics];
		for (int ti = 0; ti < numTopics; ti++) {
			tokensPerTopic[ti] = in.readInt();
		}
	}

	public static void main (String[] args) throws IOException {

		InstanceList training = InstanceList.load (new File(args[0]));

		int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;

		SimpleLDA lda = new SimpleLDA (numTopics, 50.0, 0.01);
		lda.addInstances(training);
		lda.sample(1000);
	}
}