/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import gnu.trove.TIntIntHashMap;

import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.TreeSet;
import java.util.Iterator;
import java.util.zip.*;

import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.Randoms;

/**
 * Latent Dirichlet Allocation with optimized hyperparameters
 *
 * @author David Mimno, Andrew McCallum
 * @deprecated Use ParallelTopicModel instead, which uses substantially faster data structures even for non-parallel operation.
 */
public class LDAHyper implements Serializable {

	// Analogous to a cc.mallet.classify.Classification
	public class Topication implements Serializable {
		public Instance instance;
		public LDAHyper model;
		public LabelSequence topicSequence;
		public Labeling topicDistribution; // not actually constructed by model fitting, but could be added for "test" documents.

		public Topication (Instance instance, LDAHyper model, LabelSequence topicSequence) {
			this.instance = instance;
			this.model = model;
			this.topicSequence = topicSequence;
		}

		// Maintainable serialization
		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 0;

		private void writeObject (ObjectOutputStream out) throws IOException {
			out.writeInt (CURRENT_SERIAL_VERSION);
			out.writeObject (instance);
			out.writeObject (model);
			out.writeObject (topicSequence);
			out.writeObject (topicDistribution);
		}

		private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
			int version = in.readInt ();
			instance = (Instance) in.readObject();
			model = (LDAHyper) in.readObject();
			topicSequence = (LabelSequence) in.readObject();
			topicDistribution = (Labeling) in.readObject();
		}
	}

	protected ArrayList<Topication> data;  // the training instances and their topic assignments
	protected Alphabet alphabet;           // the alphabet for the input data
	protected LabelAlphabet topicAlphabet; // the alphabet for the topics
	protected int numTopics;               // Number of topics to be fit
	protected int numTypes;

	protected double[] alpha;  // Dirichlet(alpha,alpha,...) is the distribution over topics
	protected double alphaSum;
	protected double beta;     // Prior on per-topic multinomial distribution over words
	protected double betaSum;
	public static final double DEFAULT_BETA = 0.01;

	protected double smoothingOnlyMass = 0.0;
	protected double[] cachedCoefficients;
	int topicTermCount = 0;
	int betaTopicCount = 0;
	int smoothingOnlyCount = 0;

	// Instance list for empirical likelihood calculation
	protected InstanceList testing = null;

	// An array to put the topic counts for the current document.
	// Initialized locally below.  Defined here to avoid
	// garbage collection overhead.
	protected int[] oneDocTopicCounts; // indexed by <topic index>

	protected gnu.trove.TIntIntHashMap[] typeTopicCounts; // indexed by <feature index, topic index>
	protected int[] tokensPerTopic; // indexed by <topic index>

	// for dirichlet estimation
	protected int[] docLengthCounts; // histogram of document sizes
	protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, per-document count>

	public int iterationsSoFar = 0;
	public int numIterations = 1000;
	public int burninPeriod = 20; // was 50; //was 200;
	public int saveSampleInterval = 5; // was 10;
	public int optimizeInterval = 20; // was 50;
	public int showTopicsInterval = 10; // was 50;
	public int wordsPerTopic = 7;

	protected int outputModelInterval = 0;
	protected String outputModelFilename;

	protected int saveStateInterval = 0;
	protected String stateFilename = null;

	protected Randoms random;
	protected NumberFormat formatter;
	protected boolean printLogLikelihood = false;

	public LDAHyper (int numberOfTopics) {
		this (numberOfTopics, numberOfTopics, DEFAULT_BETA);
	}

	public LDAHyper (int numberOfTopics, double alphaSum, double beta) {
		this (numberOfTopics, alphaSum, beta, new Randoms());
	}

	private static LabelAlphabet newLabelAlphabet (int numTopics) {
		LabelAlphabet ret = new LabelAlphabet();
		for (int i = 0; i < numTopics; i++)
			ret.lookupIndex("topic"+i);
		return ret;
	}

	public LDAHyper (int numberOfTopics, double alphaSum, double beta, Randoms random) {
		this (newLabelAlphabet (numberOfTopics), alphaSum, beta, random);
	}

	public LDAHyper (LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random) {
		this.data = new ArrayList<Topication>();
		this.topicAlphabet = topicAlphabet;
		this.numTopics = topicAlphabet.size();
		this.alphaSum = alphaSum;
		this.alpha = new double[numTopics];
		Arrays.fill(alpha, alphaSum / numTopics);
		this.beta = beta;
		this.random = random;

		oneDocTopicCounts = new int[numTopics];
		tokensPerTopic = new int[numTopics];

		formatter = NumberFormat.getInstance();
		formatter.setMaximumFractionDigits(5);

		System.err.println("LDA: " + numTopics + " topics");
	}

	public Alphabet getAlphabet() { return alphabet; }
	public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
	public int getNumTopics() { return numTopics; }
	public ArrayList<Topication> getData() { return data; }
	public int getCountFeatureTopic (int featureIndex, int topicIndex) { return typeTopicCounts[featureIndex].get(topicIndex); }
	public int getCountTokensPerTopic (int topicIndex) { return tokensPerTopic[topicIndex]; }

	/** Held-out instances for empirical likelihood calculation */
	public void setTestingInstances(InstanceList testing) {
		this.testing = testing;
	}

	public void setNumIterations (int numIterations) {
		this.numIterations = numIterations;
	}

	public void setBurninPeriod (int burninPeriod) {
		this.burninPeriod = burninPeriod;
	}

	public void setTopicDisplay(int interval, int n) {
		this.showTopicsInterval = interval;
		this.wordsPerTopic = n;
	}

	public void setRandomSeed(int seed) {
		random = new Randoms(seed);
	}

	public void setOptimizeInterval(int interval) {
		this.optimizeInterval = interval;
	}

	public void setModelOutput(int interval, String filename) {
		this.outputModelInterval = interval;
		this.outputModelFilename = filename;
	}
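	/* Typical usage (a minimal sketch, not part of the original API docs;
	 * the instance-list filename is hypothetical):
	 *
	 *   InstanceList training = InstanceList.load(new File("topic-input.mallet"));
	 *   LDAHyper lda = new LDAHyper(100, 50.0, LDAHyper.DEFAULT_BETA);
	 *   lda.setTopicDisplay(50, 10);
	 *   lda.addInstances(training);   // random initial topic assignments
	 *   lda.estimate();               // run numIterations Gibbs sweeps
	 */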
	/** Define how often and where to save the state
	 *
	 * @param interval Save a copy of the state every <code>interval</code> iterations.
	 * @param filename Save the state to this file, with the iteration number as a suffix
	 */
	public void setSaveState(int interval, String filename) {
		this.saveStateInterval = interval;
		this.stateFilename = filename;
	}

	protected int instanceLength (Instance instance) {
		return ((FeatureSequence)instance.getData()).size();
	}

	// Can be safely called multiple times.  This method will complain if it can't handle the situation
	private void initializeForTypes (Alphabet alphabet) {
		if (this.alphabet == null) {
			this.alphabet = alphabet;
			this.numTypes = alphabet.size();
			this.typeTopicCounts = new TIntIntHashMap[numTypes];
			for (int fi = 0; fi < numTypes; fi++)
				typeTopicCounts[fi] = new TIntIntHashMap();
			this.betaSum = beta * numTypes;
		} else if (alphabet != this.alphabet) {
			throw new IllegalArgumentException ("Cannot change Alphabet.");
		} else if (alphabet.size() != this.numTypes) {
			this.numTypes = alphabet.size();
			TIntIntHashMap[] newTypeTopicCounts = new TIntIntHashMap[numTypes];
			for (int i = 0; i < typeTopicCounts.length; i++)
				newTypeTopicCounts[i] = typeTopicCounts[i];
			for (int i = typeTopicCounts.length; i < numTypes; i++)
				newTypeTopicCounts[i] = new TIntIntHashMap();
			// TODO AKM July 18:  Why wasn't the next line there previously?
			// Without it the enlarged array is discarded and newly added types have no count maps.
			this.typeTopicCounts = newTypeTopicCounts;
			this.betaSum = beta * numTypes;
		}	// else, nothing changed, nothing to be done
	}

	private void initializeTypeTopicCounts () {
		TIntIntHashMap[] newTypeTopicCounts = new TIntIntHashMap[numTypes];
		for (int i = 0; i < typeTopicCounts.length; i++)
			newTypeTopicCounts[i] = typeTopicCounts[i];
		for (int i = typeTopicCounts.length; i < numTypes; i++)
			newTypeTopicCounts[i] = new TIntIntHashMap();
		this.typeTopicCounts = newTypeTopicCounts;
	}

	public void addInstances (InstanceList training) {
		initializeForTypes (training.getDataAlphabet());
		ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
		for (Instance instance : training) {
			LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
			if (false)
				// This method not yet obeying its last "false" argument, and must be for this to work
				sampleTopicsForOneDoc((FeatureSequence)instance.getData(), topicSequence, false, false);
			else {
				Randoms r = new Randoms();
				int[] topics = topicSequence.getFeatures();
				for (int i = 0; i < topics.length; i++)
					topics[i] = r.nextInt(numTopics);
			}
			topicSequences.add (topicSequence);
		}
		addInstances (training, topicSequences);
	}

	public void addInstances (InstanceList training, List<LabelSequence> topics) {
		initializeForTypes (training.getDataAlphabet());
		assert (training.size() == topics.size());
		for (int i = 0; i < training.size(); i++) {
			Topication t = new Topication (training.get(i), this, topics.get(i));
			data.add (t);
			// Include sufficient statistics for this one doc
			FeatureSequence tokenSequence = (FeatureSequence) t.instance.getData();
			LabelSequence topicSequence = t.topicSequence;
			for (int pi = 0; pi < topicSequence.getLength(); pi++) {
				int topic = topicSequence.getIndexAtPosition(pi);
				typeTopicCounts[tokenSequence.getIndexAtPosition(pi)].adjustOrPutValue(topic, 1, 1);
				tokensPerTopic[topic]++;
			}
		}
		initializeHistogramsAndCachedValues();
	}
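	// The docLengthCounts and topicDocCounts histograms filled in below feed the
	// Dirichlet hyperparameter update (Dirichlet.learnParameters in estimate()):
	//   docLengthCounts[n]   = number of sampled documents of length n
	//   topicDocCounts[t][k] = number of sampled documents in which topic t
	//                          was assigned to exactly k tokens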
	/**
	 *  Gather statistics on the size of documents
	 *  and create histograms for use in Dirichlet hyperparameter
	 *  optimization.
	 */
	protected void initializeHistogramsAndCachedValues() {

		int maxTokens = 0;
		int totalTokens = 0;
		int seqLen;

		for (int doc = 0; doc < data.size(); doc++) {
			FeatureSequence fs = (FeatureSequence) data.get(doc).instance.getData();
			seqLen = fs.getLength();
			if (seqLen > maxTokens)
				maxTokens = seqLen;
			totalTokens += seqLen;
		}

		// Initialize the smoothing-only sampling bucket
		smoothingOnlyMass = 0;
		for (int topic = 0; topic < numTopics; topic++)
			smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);

		// Initialize the cached coefficients, using only smoothing.
		cachedCoefficients = new double[ numTopics ];
		for (int topic=0; topic < numTopics; topic++)
			cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);

		System.err.println("max tokens: " + maxTokens);
		System.err.println("total tokens: " + totalTokens);

		docLengthCounts = new int[maxTokens + 1];
		topicDocCounts = new int[numTopics][maxTokens + 1];
	}

	public void estimate () throws IOException {
		estimate (numIterations);
	}

	public void estimate (int iterationsThisRound) throws IOException {

		long startTime = System.currentTimeMillis();
		int maxIteration = iterationsSoFar + iterationsThisRound;

		for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) {
			long iterationStart = System.currentTimeMillis();

			if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) {
				System.out.println();
				printTopWords (System.out, wordsPerTopic, false);

				if (testing != null) {
					double el = empiricalLikelihood(1000, testing);
					double ll = modelLogLikelihood();
					double mi = topicLabelMutualInformation();
					System.out.println(ll + "\t" + el + "\t" + mi);
				}
			}

			if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
				this.printState(new File(stateFilename + '.' + iterationsSoFar));
			}

			/*
			  if (outputModelInterval != 0 && iterations % outputModelInterval == 0) {
			  this.write (new File(outputModelFilename + '.' + iterations));
			  }
			*/

			// TODO this condition should also check that we have more than one sample to work with here
			// (The number of samples actually obtained is not yet tracked.)
			if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
				iterationsSoFar % optimizeInterval == 0) {

				alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);
				smoothingOnlyMass = 0.0;
				for (int topic = 0; topic < numTopics; topic++) {
					smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
					cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
				}
				clearHistograms();
			}

			// Loop over every document in the corpus
			topicTermCount = betaTopicCount = smoothingOnlyCount = 0;
			int numDocs = data.size(); // TODO consider beginning by sub-sampling?
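			// One full Gibbs sweep: resample the topic of every token in every
			// document.  Per-document topic counts are only recorded (for the
			// alpha update above) after the burn-in period, and then only every
			// saveSampleInterval iterations -- that is what the third argument
			// to sampleTopicsForOneDoc controls.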
			for (int di = 0; di < numDocs; di++) {
				FeatureSequence tokenSequence = (FeatureSequence) data.get(di).instance.getData();
				LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequence;
				sampleTopicsForOneDoc (tokenSequence, topicSequence,
									   iterationsSoFar >= burninPeriod && iterationsSoFar % saveSampleInterval == 0,
									   true);
			}

			long elapsedMillis = System.currentTimeMillis() - iterationStart;
			if (elapsedMillis < 1000) {
				System.out.print(elapsedMillis + "ms ");
			}
			else {
				System.out.print((elapsedMillis/1000) + "s ");
			}

			//System.out.println(topicTermCount + "\t" + betaTopicCount + "\t" + smoothingOnlyCount);

			if (iterationsSoFar % 10 == 0) {
				System.out.println ("<" + iterationsSoFar + "> ");
				if (printLogLikelihood)
					System.out.println (modelLogLikelihood());
			}
			System.out.flush();
		}

		long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
		long minutes = seconds / 60;	seconds %= 60;
		long hours = minutes / 60;	minutes %= 60;
		long days = hours / 24;	hours %= 24;
		System.out.print ("\nTotal time: ");
		if (days != 0) { System.out.print(days); System.out.print(" days "); }
		if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
		if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
		System.out.print(seconds);	System.out.println(" seconds");
	}

	private void clearHistograms() {
		Arrays.fill(docLengthCounts, 0);
		for (int topic = 0; topic < topicDocCounts.length; topic++)
			Arrays.fill(topicDocCounts[topic], 0);
	}

	/** If topicSequence assignments are already set and accounted for in sufficient statistics,
	 *   then readjustTopicsAndStats should be true.  The topics will be re-sampled and the sufficient statistics changed.
	 *  If operating on a new or a test document, and featureSequence & topicSequence are not already accounted for in the sufficient statistics,
	 *   then readjustTopicsAndStats should be false.  The current topic assignments will be ignored, and the sufficient statistics
	 *   will not be changed.
	 *  If you want to estimate the Dirichlet alpha based on the per-document topic multinomials sampled this round,
	 *   then saveStateForAlphaEstimation should be true.
	 */
	private void oldSampleTopicsForOneDoc (FeatureSequence featureSequence,
			FeatureSequence topicSequence,
			boolean saveStateForAlphaEstimation, boolean readjustTopicsAndStats)
	{
		long startTime = System.currentTimeMillis();

		int[] oneDocTopics = topicSequence.getFeatures();

		TIntIntHashMap currentTypeTopicCounts;
		int type, oldTopic, newTopic;
		double[] topicDistribution;
		double topicDistributionSum;
		int docLen = featureSequence.getLength();
		int adjustedValue;
		int[] topicIndices, topicCounts;

		double weight;

		// populate topic counts
		Arrays.fill(oneDocTopicCounts, 0);

		if (readjustTopicsAndStats) {
			for (int token = 0; token < docLen; token++) {
				oneDocTopicCounts[ oneDocTopics[token] ]++;
			}
		}

		// Iterate over the tokens (words) in the document
		for (int token = 0; token < docLen; token++) {
			type = featureSequence.getIndexAtPosition(token);
			oldTopic = oneDocTopics[token];
			currentTypeTopicCounts = typeTopicCounts[type];
			assert (currentTypeTopicCounts.size() != 0);

			if (readjustTopicsAndStats) {
				// Remove this token from all counts
				oneDocTopicCounts[oldTopic]--;
				adjustedValue = currentTypeTopicCounts.adjustOrPutValue(oldTopic, -1, -1);
				if (adjustedValue == 0) currentTypeTopicCounts.remove(oldTopic);
				else if (adjustedValue == -1) throw new IllegalStateException ("Token count in topic went negative.");
				tokensPerTopic[oldTopic]--;
			}

			// Build a distribution over topics for this token
			topicIndices = currentTypeTopicCounts.keys();
			topicCounts = currentTypeTopicCounts.getValues();
			topicDistribution = new double[topicIndices.length];
			// TODO Yipes, memory allocation in the inner loop!  But note that .keys and .getValues is doing this too.
			topicDistributionSum = 0;
			for (int i = 0; i < topicCounts.length; i++) {
				int topic = topicIndices[i];
				weight = ((topicCounts[i] + beta) / (tokensPerTopic[topic] + betaSum)) * ((oneDocTopicCounts[topic] + alpha[topic]));
				topicDistributionSum += weight;
				topicDistribution[i] = weight;	// indexed by position, parallel to topicIndices
			}

			// Sample a topic assignment from this distribution
			newTopic = topicIndices[random.nextDiscrete (topicDistribution, topicDistributionSum)];

			if (readjustTopicsAndStats) {
				// Put that new topic into the counts
				oneDocTopics[token] = newTopic;
				oneDocTopicCounts[newTopic]++;
				typeTopicCounts[type].adjustOrPutValue(newTopic, 1, 1);
				tokensPerTopic[newTopic]++;
			}
		}

		if (saveStateForAlphaEstimation) {
			// Update the document-topic count histogram, for dirichlet estimation
			docLengthCounts[ docLen ]++;
			for (int topic=0; topic < numTopics; topic++) {
				topicDocCounts[topic][ oneDocTopicCounts[topic] ]++;
			}
		}
	}

	protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
										  FeatureSequence topicSequence,
										  boolean shouldSaveState,
										  boolean readjustTopicsAndStats /* currently ignored */) {

		int[] oneDocTopics = topicSequence.getFeatures();

		TIntIntHashMap currentTypeTopicCounts;
		int type, oldTopic, newTopic;
		double topicWeightsSum;
		int docLength = tokenSequence.getLength();

		// populate topic counts
		TIntIntHashMap localTopicCounts = new TIntIntHashMap();
		for (int position = 0; position < docLength; position++) {
			localTopicCounts.adjustOrPutValue(oneDocTopics[position], 1, 1);
		}

		// Initialize the topic count/beta sampling bucket
		double topicBetaMass = 0.0;
		for (int topic: localTopicCounts.keys()) {
			int n = localTopicCounts.get(topic);

			// initialize the normalization constant for the (B * n_{t|d}) term
			topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);

			// update the coefficients for the non-zero topics
			cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
		}
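		// The total sampling mass for a token of type w is split into three
		// buckets, in the style of the SparseLDA sampler (Yao, Mimno & McCallum, KDD 2009):
		//
		//   smoothingOnlyMass = sum_t  alpha_t * beta               / (N_t + betaSum)
		//   topicBetaMass     = sum_t  beta * n_{t|d}               / (N_t + betaSum)
		//   topicTermMass     = sum_t (alpha_t + n_{t|d}) * n_{w|t} / (N_t + betaSum)
		//
		// where n_{t|d} is this document's count for topic t, n_{w|t} is the
		// type/topic count, and N_t = tokensPerTopic[t].  Only topicTermMass
		// depends on the word type, so it is rebuilt per token below, while the
		// other two masses are maintained incrementally.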
		double topicTermMass = 0.0;

		double[] topicTermScores = new double[numTopics];
		int[] topicTermIndices;
		int[] topicTermValues;
		int i;
		double score;

		// Iterate over the positions (words) in the document
		for (int position = 0; position < docLength; position++) {
			type = tokenSequence.getIndexAtPosition(position);
			oldTopic = oneDocTopics[position];

			currentTypeTopicCounts = typeTopicCounts[type];
			assert(currentTypeTopicCounts.get(oldTopic) >= 0);

			// Remove this token from all counts.
			// Note that we actually want to remove the key if it goes
			//  to zero, not set it to 0.
			if (currentTypeTopicCounts.get(oldTopic) == 1) {
				currentTypeTopicCounts.remove(oldTopic);
			}
			else {
				currentTypeTopicCounts.adjustValue(oldTopic, -1);
			}

			smoothingOnlyMass -= alpha[oldTopic] * beta /
				(tokensPerTopic[oldTopic] + betaSum);
			topicBetaMass -= beta * localTopicCounts.get(oldTopic) /
				(tokensPerTopic[oldTopic] + betaSum);

			if (localTopicCounts.get(oldTopic) == 1) {
				localTopicCounts.remove(oldTopic);
			}
			else {
				localTopicCounts.adjustValue(oldTopic, -1);
			}

			tokensPerTopic[oldTopic]--;

			smoothingOnlyMass += alpha[oldTopic] * beta /
				(tokensPerTopic[oldTopic] + betaSum);
			topicBetaMass += beta * localTopicCounts.get(oldTopic) /
				(tokensPerTopic[oldTopic] + betaSum);

			cachedCoefficients[oldTopic] =
				(alpha[oldTopic] + localTopicCounts.get(oldTopic)) /
				(tokensPerTopic[oldTopic] + betaSum);

			topicTermMass = 0.0;

			topicTermIndices = currentTypeTopicCounts.keys();
			topicTermValues = currentTypeTopicCounts.getValues();

			for (i=0; i < topicTermIndices.length; i++) {
				int topic = topicTermIndices[i];
				score =
					cachedCoefficients[topic] * topicTermValues[i];
				//  ((alpha[topic] + localTopicCounts.get(topic)) *
				//   topicTermValues[i]) /
				//  (tokensPerTopic[topic] + betaSum);

				// Note: I tried only doing this next bit if
				//  score > 0, but it didn't make any difference,
				//  at least in the first few iterations.

				topicTermMass += score;
				topicTermScores[i] = score;
				// topicTermIndices[i] = topic;
			}
			// indicate that this is the last topic
			// topicTermIndices[i] = -1;
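			// Draw u ~ Uniform(0, smoothingOnlyMass + topicBetaMass + topicTermMass)
			// and find which bucket it falls in.  The topic/term bucket only
			// requires scanning the topics that actually contain this word type.
			// For the beta and smoothing buckets the remaining mass is divided
			// by beta so that each topic's share can be subtracted without
			// multiplying by beta again.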
			double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
			double origSample = sample;

			// Make sure it actually gets set
			newTopic = -1;

			if (sample < topicTermMass) {
				//topicTermCount++;

				i = -1;
				while (sample > 0) {
					i++;
					sample -= topicTermScores[i];
				}
				newTopic = topicTermIndices[i];
			}
			else {
				sample -= topicTermMass;

				if (sample < topicBetaMass) {
					//betaTopicCount++;

					sample /= beta;

					topicTermIndices = localTopicCounts.keys();
					topicTermValues = localTopicCounts.getValues();

					for (i=0; i < topicTermIndices.length; i++) {
						newTopic = topicTermIndices[i];

						sample -= topicTermValues[i] /
							(tokensPerTopic[newTopic] + betaSum);

						if (sample <= 0.0) {
							break;
						}
					}
				}
				else {
					//smoothingOnlyCount++;

					sample -= topicBetaMass;

					sample /= beta;

					for (int topic = 0; topic < numTopics; topic++) {
						sample -= alpha[topic] /
							(tokensPerTopic[topic] + betaSum);

						if (sample <= 0.0) {
							newTopic = topic;
							break;
						}
					}
				}
			}

			if (newTopic == -1) {
				System.err.println("LDAHyper sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
						topicBetaMass + " " + topicTermMass);
				newTopic = numTopics-1; // TODO is this appropriate
				//throw new IllegalStateException ("LDAHyper: New topic not sampled.");
			}
			//assert(newTopic != -1);

			// Put that new topic into the counts
			oneDocTopics[position] = newTopic;

			currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);

			smoothingOnlyMass -= alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
			topicBetaMass -= beta * localTopicCounts.get(newTopic) / (tokensPerTopic[newTopic] + betaSum);

			localTopicCounts.adjustOrPutValue(newTopic, 1, 1);

			tokensPerTopic[newTopic]++;

			// update the coefficients for the non-zero topics
			cachedCoefficients[newTopic] =
				(alpha[newTopic] + localTopicCounts.get(newTopic)) / (tokensPerTopic[newTopic] + betaSum);

			smoothingOnlyMass += alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
			topicBetaMass += beta * localTopicCounts.get(newTopic) / (tokensPerTopic[newTopic] + betaSum);

			assert(currentTypeTopicCounts.get(newTopic) >= 0);
		}

		// Clean up our mess: reset the coefficients to values with only
		// smoothing. The next doc will update its own non-zero topics...
		for (int topic: localTopicCounts.keys()) {
			cachedCoefficients[topic] =
				alpha[topic] / (tokensPerTopic[topic] + betaSum);
		}

		if (shouldSaveState) {
			// Update the document-topic count histogram,
			//  for dirichlet estimation
			docLengthCounts[ docLength ]++;
			for (int topic: localTopicCounts.keys()) {
				topicDocCounts[topic][ localTopicCounts.get(topic) ]++;
			}
		}
	}

	public IDSorter[] getSortedTopicWords(int topic) {
		IDSorter[] sortedTypes = new IDSorter[ numTypes ];
		for (int type = 0; type < numTypes; type++)
			sortedTypes[type] = new IDSorter(type, typeTopicCounts[type].get(topic));
		Arrays.sort(sortedTypes);
		return sortedTypes;
	}

	public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
		PrintStream out = new PrintStream (file);
		printTopWords(out, numWords, useNewLines);
		out.close();
	}

	// TreeSet implementation is ~70x faster than RankedFeatureVector -DM
	public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {
		for (int topic = 0; topic < numTopics; topic++) {

			TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
			for (int type = 0; type < numTypes; type++) {
				if (typeTopicCounts[type].containsKey(topic)) {
					sortedWords.add(new IDSorter(type, typeTopicCounts[type].get(topic)));
				}
			}

			if (usingNewLines) {
				out.println ("Topic " + topic);

				int word = 1;
				Iterator<IDSorter> iterator = sortedWords.iterator();
				while (iterator.hasNext() && word < numWords) {
					IDSorter info = iterator.next();
					out.println(alphabet.lookupObject(info.getID()) + "\t" + (int) info.getWeight());
					word++;
				}
			}
			else {
				out.print (topic + "\t" + formatter.format(alpha[topic]) + "\t" + tokensPerTopic[topic] + "\t");

				int word = 1;
				Iterator<IDSorter> iterator = sortedWords.iterator();
				while (iterator.hasNext() && word < numWords) {
					IDSorter info = iterator.next();
					out.print(alphabet.lookupObject(info.getID()) + " ");
					word++;
				}

				out.println();
			}
		}
	}

	public void topicXMLReport (PrintWriter out, int numWords) {
		out.println("<?xml version='1.0' ?>");
		out.println("<topicModel>");
		for (int topic = 0; topic < numTopics; topic++) {
			out.println("  <topic id='" + topic + "' alpha='" + alpha[topic] +
						"' totalTokens='" + tokensPerTopic[topic] + "'>");

			TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
			for (int type = 0; type < numTypes; type++) {
				if (typeTopicCounts[type].containsKey(topic)) {
					sortedWords.add(new IDSorter(type, typeTopicCounts[type].get(topic)));
				}
			}

			int word = 1;
			Iterator<IDSorter> iterator = sortedWords.iterator();
			while (iterator.hasNext() && word < numWords) {
				IDSorter info = iterator.next();
				out.println("    <word rank='" + word + "'>" +
							alphabet.lookupObject(info.getID()) +
							"</word>");
				word++;
			}

			out.println("  </topic>");
		}
		out.println("</topicModel>");
	}
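	// Builds an XML report that, in addition to the top single words, collects
	// runs of adjacent tokens assigned to the same topic as candidate "phrases",
	// and uses the highest-ranked words and phrases to suggest a title for each topic.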
	public void topicXMLReportPhrases (PrintStream out, int numWords) {
		int numTopics = this.getNumTopics();
		gnu.trove.TObjectIntHashMap<String>[] phrases = new gnu.trove.TObjectIntHashMap[numTopics];
		Alphabet alphabet = this.getAlphabet();

		// Get counts of phrases
		for (int ti = 0; ti < numTopics; ti++)
			phrases[ti] = new gnu.trove.TObjectIntHashMap<String>();
		for (int di = 0; di < this.getData().size(); di++) {
			LDAHyper.Topication t = this.getData().get(di);
			Instance instance = t.instance;
			FeatureSequence fvs = (FeatureSequence) instance.getData();
			boolean withBigrams = false;
			if (fvs instanceof FeatureSequenceWithBigrams) withBigrams = true;
			int prevtopic = -1;
			int prevfeature = -1;
			int topic = -1;
			StringBuffer sb = null;
			int feature = -1;
			int doclen = fvs.size();
			for (int pi = 0; pi < doclen; pi++) {
				feature = fvs.getIndexAtPosition(pi);
				topic = this.getData().get(di).topicSequence.getIndexAtPosition(pi);
				if (topic == prevtopic && (!withBigrams || ((FeatureSequenceWithBigrams)fvs).getBiIndexAtPosition(pi) != -1)) {
					if (sb == null)
						sb = new StringBuffer (alphabet.lookupObject(prevfeature).toString() + " " + alphabet.lookupObject(feature));
					else {
						sb.append (" ");
						sb.append (alphabet.lookupObject(feature));
					}
				} else if (sb != null) {
					String sbs = sb.toString();
					//System.out.println ("phrase:"+sbs);
					if (phrases[prevtopic].get(sbs) == 0)
						phrases[prevtopic].put(sbs,0);
					phrases[prevtopic].increment(sbs);
					prevtopic = prevfeature = -1;
					sb = null;
				} else {
					prevtopic = topic;
					prevfeature = feature;
				}
			}
		}
		// phrases[] now filled with counts

		// Now start printing the XML
		out.println("<?xml version='1.0' ?>");
		out.println("<topics>");

		double[] probs = new double[alphabet.size()];
		for (int ti = 0; ti < numTopics; ti++) {
			out.print("  <topic id=\"" + ti + "\" alpha=\"" + alpha[ti] +
					"\" totalTokens=\"" + tokensPerTopic[ti] + "\" ");

			// For gathering <term> and <phrase> output temporarily
			//  so that we can get topic-title information before printing it to "out".
			ByteArrayOutputStream bout = new ByteArrayOutputStream();
			PrintStream pout = new PrintStream (bout);
			// For holding candidate topic titles
			AugmentableFeatureVector titles = new AugmentableFeatureVector (new Alphabet());

			// Print words
			for (int type = 0; type < alphabet.size(); type++)
				probs[type] = this.getCountFeatureTopic(type, ti) / (double)this.getCountTokensPerTopic(ti);
			RankedFeatureVector rfv = new RankedFeatureVector (alphabet, probs);
			for (int ri = 0; ri < numWords; ri++) {
				int fi = rfv.getIndexAtRank(ri);
				pout.println ("    <term weight=\""+probs[fi]+"\" count=\""+this.getCountFeatureTopic(fi,ti)+"\">"+alphabet.lookupObject(fi)+"</term>");
				if (ri < 20) // consider top 20 individual words as candidate titles
					titles.add(alphabet.lookupObject(fi), this.getCountFeatureTopic(fi,ti));
			}

			// Print phrases
			Object[] keys = phrases[ti].keys();
			int[] values = phrases[ti].getValues();
			double counts[] = new double[keys.length];
			for (int i = 0; i < counts.length; i++)	counts[i] = values[i];
			double countssum = MatrixOps.sum (counts);
			Alphabet alph = new Alphabet(keys);
			rfv = new RankedFeatureVector (alph, counts);
			//out.println ("topic "+ti);
			int max = rfv.numLocations() < numWords ? rfv.numLocations() : numWords;
			//System.out.println ("topic "+ti+" numPhrases="+rfv.numLocations());
			for (int ri = 0; ri < max; ri++) {
				int fi = rfv.getIndexAtRank(ri);
				pout.println ("    <phrase weight=\""+counts[fi]/countssum+"\" count=\""+values[fi]+"\">"+alph.lookupObject(fi)+"</phrase>");
				// Any phrase count less than 20 is simply unreliable
				if (ri < 20 && values[fi] > 20)
					titles.add(alph.lookupObject(fi), 100*values[fi]); // prefer phrases with a factor of 100
			}

			// Select candidate titles
			StringBuffer titlesStringBuffer = new StringBuffer();
			rfv = new RankedFeatureVector (titles.getAlphabet(), titles);
			int numTitles = 10;
			for (int ri = 0; ri < numTitles && ri < rfv.numLocations(); ri++) {
				// Don't add redundant titles
				if (titlesStringBuffer.indexOf(rfv.getObjectAtRank(ri).toString()) == -1) {
					titlesStringBuffer.append (rfv.getObjectAtRank(ri));
					if (ri < numTitles-1)
						titlesStringBuffer.append (", ");
				} else
					numTitles++;
			}
			out.println("titles=\"" + titlesStringBuffer.toString() + "\">");
			// Flush the buffered <term>/<phrase> output and emit it after the titles.
			// (Printing pout itself would only print the stream object, not its contents.)
			pout.flush();
			out.print(bout.toString());
			out.println("  </topic>");
		}
		out.println("</topics>");
	}

	public void printDocumentTopics (File f) throws IOException {
		printDocumentTopics (new PrintWriter (new FileWriter (f) ) );
	}

	public void printDocumentTopics (PrintWriter pw) {
		printDocumentTopics (pw, 0.0, -1);
	}

	/**
	 *  @param pw          A print writer
	 *  @param threshold   Only print topics with proportion greater than this number
	 *  @param max         Print no more than this many topics
	 */
	public void printDocumentTopics (PrintWriter pw, double threshold, int max) {
		pw.print ("#doc source topic proportion ...\n");
		int docLen;
		int[] topicCounts = new int[ numTopics ];

		IDSorter[] sortedTopics = new IDSorter[ numTopics ];
		for (int topic = 0; topic < numTopics; topic++) {
			// Initialize the sorters with dummy values
			sortedTopics[topic] = new IDSorter(topic, topic);
		}

		if (max < 0 || max > numTopics) {
			max = numTopics;
		}

		for (int di = 0; di < data.size(); di++) {
			LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequence;
			int[] currentDocTopics = topicSequence.getFeatures();

			pw.print (di); pw.print (' ');

			if (data.get(di).instance.getSource() != null) {
				pw.print (data.get(di).instance.getSource());
			}
			else {
				pw.print ("null-source");
			}

			pw.print (' ');
			docLen = currentDocTopics.length;

			// Count up the tokens
			for (int token=0; token < docLen; token++) {
				topicCounts[ currentDocTopics[token] ]++;
			}

			// And normalize
			for (int topic = 0; topic < numTopics; topic++) {
				sortedTopics[topic].set(topic, (float) topicCounts[topic] / docLen);
			}

			Arrays.sort(sortedTopics);

			for (int i = 0; i < max; i++) {
				if (sortedTopics[i].getWeight() < threshold) { break; }

				pw.print (sortedTopics[i].getID() + " " +
						  sortedTopics[i].getWeight() + " ");
			}
			pw.print (" \n");

			Arrays.fill(topicCounts, 0);
		}
	}
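	/* The Gibbs state written below is a gzipped text file with the header
	 * "#doc source pos typeindex type topic" and one line per token, recording
	 * the current topic assignment of every token in the corpus. */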
	public void printState (File f) throws IOException {
		PrintStream out =
			new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
		printState(out);
		out.close();
	}

	public void printState (PrintStream out) {

		out.println ("#doc source pos typeindex type topic");

		for (int di = 0; di < data.size(); di++) {
			FeatureSequence tokenSequence = (FeatureSequence) data.get(di).instance.getData();
			LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequence;

			String source = "NA";
			if (data.get(di).instance.getSource() != null) {
				source = data.get(di).instance.getSource().toString();
			}

			for (int pi = 0; pi < topicSequence.getLength(); pi++) {
				int type = tokenSequence.getIndexAtPosition(pi);
				int topic = topicSequence.getIndexAtPosition(pi);
				out.print(di); out.print(' ');
				out.print(source); out.print(' ');
				out.print(pi); out.print(' ');
				out.print(type); out.print(' ');
				out.print(alphabet.lookupObject(type)); out.print(' ');
				out.print(topic); out.println();
			}
		}
	}

	// Turbo topics
	/*
	private class CorpusWordCounts {
		Alphabet unigramAlphabet;
		FeatureCounter unigramCounts = new FeatureCounter(unigramAlphabet);
		public CorpusWordCounts(Alphabet alphabet) {
			unigramAlphabet = alphabet;
		}
		private double mylog(double x) {
			return (x == 0) ? -1000000.0 : Math.log(x);
		}
		// The likelihood ratio significance test
		private double significanceTest(int thisUnigramCount, int nextUnigramCount, int nextBigramCount, int nextTotalCount, int minCount) {
			if (nextBigramCount < minCount) return -1.0;
			assert(nextUnigramCount >= nextBigramCount);
			double log_pi_vu = mylog(nextBigramCount) - mylog(thisUnigramCount);
			double log_pi_vnu = mylog(nextUnigramCount - nextBigramCount) - mylog(nextTotalCount - nextBigramCount);
			double log_pi_v_old = mylog(nextUnigramCount) - mylog(nextTotalCount);
			double log_1mp_v = mylog(1 - Math.exp(log_pi_vnu));
			double log_1mp_vu = mylog(1 - Math.exp(log_pi_vu));
			return 2 * (nextBigramCount * log_pi_vu
						+ (nextUnigramCount - nextBigramCount) * log_pi_vnu
						- nextUnigramCount * log_pi_v_old
						+ (thisUnigramCount - nextBigramCount) * (log_1mp_vu - log_1mp_v));
		}
		public int[] significantBigrams(int word) {
		}
	}
	*/

	public void write (File f) {
		try {
			ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(f));
			oos.writeObject(this);
			oos.close();
		}
		catch (IOException e) {
			System.err.println("LDAHyper.write: Exception writing LDAHyper to file " + f + ": " + e);
		}
	}
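	/* Serialization round-trip sketch (the model filename is hypothetical):
	 *
	 *   lda.write(new File("lda.model"));
	 *   LDAHyper restored = LDAHyper.read(new File("lda.model"));
	 *   restored.estimate(100);   // continue sampling from the saved state
	 */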
	public static LDAHyper read (File f) {
		LDAHyper lda = null;
		try {
			ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
			lda = (LDAHyper) ois.readObject();
			lda.initializeTypeTopicCounts();  // To work around a bug in Trove?
			ois.close();
		}
		catch (IOException e) {
			System.err.println("Exception reading file " + f + ": " + e);
		}
		catch (ClassNotFoundException e) {
			System.err.println("Exception reading file " + f + ": " + e);
		}
		return lda;
	}

	// Serialization

	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 0;
	private static final int NULL_INTEGER = -1;

	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);

		// Instance lists
		out.writeObject (data);
		out.writeObject (alphabet);
		out.writeObject (topicAlphabet);

		out.writeInt (numTopics);
		out.writeObject (alpha);
		out.writeDouble (beta);
		out.writeDouble (betaSum);

		out.writeDouble(smoothingOnlyMass);
		out.writeObject(cachedCoefficients);

		out.writeInt(iterationsSoFar);
		out.writeInt(numIterations);

		out.writeInt(burninPeriod);
		out.writeInt(saveSampleInterval);
		out.writeInt(optimizeInterval);
		out.writeInt(showTopicsInterval);
		out.writeInt(wordsPerTopic);

		out.writeInt(outputModelInterval);
		out.writeObject(outputModelFilename);

		out.writeInt(saveStateInterval);
		out.writeObject(stateFilename);

		out.writeObject(random);
		out.writeObject(formatter);
		out.writeBoolean(printLogLikelihood);

		out.writeObject(docLengthCounts);
		out.writeObject(topicDocCounts);

		for (int fi = 0; fi < numTypes; fi++)
			out.writeObject (typeTopicCounts[fi]);

		for (int ti = 0; ti < numTopics; ti++)
			out.writeInt (tokensPerTopic[ti]);
	}

	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		int featuresLength;
		int version = in.readInt ();

		data = (ArrayList<Topication>) in.readObject ();
		alphabet = (Alphabet) in.readObject();
		topicAlphabet = (LabelAlphabet) in.readObject();

		numTopics = in.readInt();
		alpha = (double[]) in.readObject();
		beta = in.readDouble();
		betaSum = in.readDouble();

		smoothingOnlyMass = in.readDouble();
		cachedCoefficients = (double[]) in.readObject();

		iterationsSoFar = in.readInt();
		numIterations = in.readInt();

		burninPeriod = in.readInt();
		saveSampleInterval = in.readInt();
		optimizeInterval = in.readInt();
		showTopicsInterval = in.readInt();
		wordsPerTopic = in.readInt();

		outputModelInterval = in.readInt();
		outputModelFilename = (String) in.readObject();

		saveStateInterval = in.readInt();
		stateFilename = (String) in.readObject();

		random = (Randoms) in.readObject();
		formatter = (NumberFormat) in.readObject();
		printLogLikelihood = in.readBoolean();

		docLengthCounts = (int[]) in.readObject();
		topicDocCounts = (int[][]) in.readObject();

		int numDocs = data.size();
		this.numTypes = alphabet.size();

		typeTopicCounts = new TIntIntHashMap[numTypes];
		for (int fi = 0; fi < numTypes; fi++)
			typeTopicCounts[fi] = (TIntIntHashMap) in.readObject();

		tokensPerTopic = new int[numTopics];
		for (int ti = 0; ti < numTopics; ti++)
			tokensPerTopic[ti] = in.readInt();
	}
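	/* Mutual information between topic assignments and document labels,
	 * I(topic; label) = H(topic) + H(label) - H(topic, label), measured in bits
	 * over all tokens.  Returns 0 if the instances have no target labels. */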
	public double topicLabelMutualInformation() {
		int doc, level, label, topic, token, type;
		int[] docTopics;

		if (data.get(0).instance.getTargetAlphabet() == null) {
			return 0.0;
		}

		int targetAlphabetSize = data.get(0).instance.getTargetAlphabet().size();
		int[][] topicLabelCounts = new int[ numTopics ][ targetAlphabetSize ];
		int[] topicCounts = new int[ numTopics ];
		int[] labelCounts = new int[ targetAlphabetSize ];
		int total = 0;

		for (doc=0; doc < data.size(); doc++) {
			label = data.get(doc).instance.getLabeling().getBestIndex();

			LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;
			docTopics = topicSequence.getFeatures();

			for (token = 0; token < docTopics.length; token++) {
				topic = docTopics[token];
				topicLabelCounts[ topic ][ label ]++;
				topicCounts[topic]++;
				labelCounts[label]++;
				total++;
			}
		}

		/* // This block will print out the best topics for each label

		IDSorter[] wp = new IDSorter[numTypes];

		for (topic = 0; topic < numTopics; topic++) {

			for (type = 0; type < numTypes; type++) {
				wp[type] = new IDSorter (type, (((double) typeTopicCounts[type][topic]) /
												tokensPerTopic[topic]));
			}
			Arrays.sort (wp);

			StringBuffer terms = new StringBuffer();
			for (int i = 0; i < 8; i++) {
				terms.append(instances.getDataAlphabet().lookupObject(wp[i].id));
				terms.append(" ");
			}

			System.out.println(terms);
			for (label = 0; label < topicLabelCounts[topic].length; label++) {
				System.out.println(topicLabelCounts[ topic ][ label ] + "\t" +
								   instances.getTargetAlphabet().lookupObject(label));
			}
			System.out.println();
		}

		*/

		double topicEntropy = 0.0;
		double labelEntropy = 0.0;
		double jointEntropy = 0.0;
		double p;
		double log2 = Math.log(2);

		for (topic = 0; topic < topicCounts.length; topic++) {
			if (topicCounts[topic] == 0) { continue; }
			p = (double) topicCounts[topic] / total;
			topicEntropy -= p * Math.log(p) / log2;
		}

		for (label = 0; label < labelCounts.length; label++) {
			if (labelCounts[label] == 0) { continue; }
			p = (double) labelCounts[label] / total;
			labelEntropy -= p * Math.log(p) / log2;
		}

		for (topic = 0; topic < topicCounts.length; topic++) {
			for (label = 0; label < labelCounts.length; label++) {
				if (topicLabelCounts[ topic ][ label ] == 0) { continue; }
				p = (double) topicLabelCounts[ topic ][ label ] / total;
				jointEntropy -= p * Math.log(p) / log2;
			}
		}

		return topicEntropy + labelEntropy - jointEntropy;
	}
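	/* Estimates held-out log likelihood by importance sampling: for each of
	 * numSamples draws, a document-topic distribution is sampled from
	 * Dirichlet(alpha), combined with the trained topic-word counts into a
	 * single multinomial over the vocabulary, and each test document is scored
	 * under it.  The per-document scores are then averaged in log space
	 * (log-sum-exp minus log numSamples). */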
	public double empiricalLikelihood(int numSamples, InstanceList testing) {
		double[][] likelihoods = new double[ testing.size() ][ numSamples ];
		double[] multinomial = new double[numTypes];
		double[] topicDistribution, currentSample, currentWeights;
		Dirichlet topicPrior = new Dirichlet(alpha);

		int sample, doc, topic, type, token, seqLen;
		FeatureSequence fs;

		for (sample = 0; sample < numSamples; sample++) {
			topicDistribution = topicPrior.nextDistribution();
			Arrays.fill(multinomial, 0.0);

			for (topic = 0; topic < numTopics; topic++) {
				for (type=0; type<numTypes; type++) {
					multinomial[type] +=
						topicDistribution[topic] *
						(beta + typeTopicCounts[type].get(topic)) /
						(betaSum + tokensPerTopic[topic]);
				}
			}

			// Convert to log probabilities
			for (type=0; type<numTypes; type++) {
				assert(multinomial[type] > 0.0);
				multinomial[type] = Math.log(multinomial[type]);
			}

			for (doc=0; doc<testing.size(); doc++) {
				fs = (FeatureSequence) testing.get(doc).getData();
				seqLen = fs.getLength();

				for (token = 0; token < seqLen; token++) {
					type = fs.getIndexAtPosition(token);

					// Adding this check since testing instances may
					//  have types not found in training instances,
					//  as pointed out by Steven Bethard.
					if (type < numTypes) {
						likelihoods[doc][sample] += multinomial[type];
					}
				}
			}
		}

		double averageLogLikelihood = 0.0;
		double logNumSamples = Math.log(numSamples);
		for (doc=0; doc<testing.size(); doc++) {
			double max = Double.NEGATIVE_INFINITY;
			for (sample = 0; sample < numSamples; sample++) {
				if (likelihoods[doc][sample] > max) {
					max = likelihoods[doc][sample];
				}
			}

			double sum = 0.0;
			for (sample = 0; sample < numSamples; sample++) {
				sum += Math.exp(likelihoods[doc][sample] - max);
			}

			averageLogLikelihood += Math.log(sum) + max - logNumSamples;
		}

		return averageLogLikelihood;
	}

	public double modelLogLikelihood() {
		double logLikelihood = 0.0;
		int nonZeroTopics;

		// The likelihood of the model is a combination of a
		// Dirichlet-multinomial for the words in each topic
		// and a Dirichlet-multinomial for the topics in each
		// document.

		// The likelihood function of a dirichlet multinomial is
		//	 Gamma( sum_i alpha_i )	 prod_i Gamma( alpha_i + N_i )
		//	prod_i Gamma( alpha_i )	  Gamma( sum_i (alpha_i + N_i) )

		// So the log likelihood is
		//	logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) +
		//	 sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]

		// Do the documents first

		int[] topicCounts = new int[numTopics];
		double[] topicLogGammas = new double[numTopics];
		int[] docTopics;

		for (int topic=0; topic < numTopics; topic++) {
			topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
		}

		for (int doc=0; doc < data.size(); doc++) {
			LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;

			docTopics = topicSequence.getFeatures();

			for (int token=0; token < docTopics.length; token++) {
				topicCounts[ docTopics[token] ]++;
			}

			for (int topic=0; topic < numTopics; topic++) {
				if (topicCounts[topic] > 0) {
					logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
									  topicLogGammas[ topic ]);
				}
			}

			// subtract the (count + parameter) sum term
			logLikelihood -= Dirichlet.logGammaStirling(alphaSum + docTopics.length);

			Arrays.fill(topicCounts, 0);
		}

		// add the parameter sum term
		logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);

		// And the topics

		// Count the number of type-topic pairs
		int nonZeroTypeTopics = 0;

		for (int type=0; type < numTypes; type++) {
			int[] usedTopics = typeTopicCounts[type].keys();

			for (int topic : usedTopics) {
				int count = typeTopicCounts[type].get(topic);
				if (count > 0) {
					nonZeroTypeTopics++;
					logLikelihood += Dirichlet.logGammaStirling(beta + count);
				}
			}
		}

		// The normalizer of each topic's word distribution sums beta over the
		// vocabulary, i.e. betaSum = beta * numTypes (not beta * numTopics).
		for (int topic=0; topic < numTopics; topic++) {
			logLikelihood -=
				Dirichlet.logGammaStirling( betaSum + tokensPerTopic[ topic ] );
		}

		logLikelihood +=
			(Dirichlet.logGammaStirling(betaSum) * numTopics) -
			(Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);

		return logLikelihood;
	}

	// Recommended to use mallet/bin/vectors2topics instead.
	public static void main (String[] args) throws IOException {
		InstanceList training = InstanceList.load (new File(args[0]));

		int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;

		InstanceList testing =
			args.length > 2 ? InstanceList.load (new File(args[2])) : null;

		LDAHyper lda = new LDAHyper (numTopics, 50.0, 0.01);
		lda.printLogLikelihood = true;

		lda.setTopicDisplay(50, 7);

		if (testing != null) {
			lda.setTestingInstances(testing);
		}

		lda.addInstances(training);
		lda.estimate();
	}

}