/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org. For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.util.*;
import java.util.zip.*;
import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;

/**
 * Latent Dirichlet Allocation for loosely parallel corpora in arbitrary languages
 *
 * @author David Mimno, Andrew McCallum
 */
public class PolylingualTopicModel implements Serializable {

    static CommandOption.SpacedStrings languageInputFiles = new CommandOption.SpacedStrings
        (PolylingualTopicModel.class, "language-inputs", "FILENAME [FILENAME ...]", true, null,
         "Filenames for polylingual topic model. Each language should have its own file, " +
         "with the same number of instances in each file. If a document is missing in " +
         "one language, there should be an empty instance.", null);

    static CommandOption.String outputModelFilename = new CommandOption.String
        (PolylingualTopicModel.class, "output-model", "FILENAME", true, null,
         "The filename in which to write the binary topic model at the end of the iterations. " +
         "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String inputModelFilename = new CommandOption.String
        (PolylingualTopicModel.class, "input-model", "FILENAME", true, null,
         "The filename from which to read the binary topic model to which the --input will be appended, " +
         "allowing incremental training. " +
         "By default this is null, indicating that no file will be read.", null);

    static CommandOption.String inferencerFilename = new CommandOption.String
        (PolylingualTopicModel.class, "inferencer-filename", "FILENAME", true, null,
         "A topic inferencer applies a previously trained topic model to new documents. " +
         "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String evaluatorFilename = new CommandOption.String
        (PolylingualTopicModel.class, "evaluator-filename", "FILENAME", true, null,
         "A held-out likelihood evaluator for new documents. " +
         "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String stateFile = new CommandOption.String
        (PolylingualTopicModel.class, "output-state", "FILENAME", true, null,
         "The filename in which to write the Gibbs sampling state at the end of the iterations. " +
         "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String topicKeysFile = new CommandOption.String
        (PolylingualTopicModel.class, "output-topic-keys", "FILENAME", true, null,
         "The filename in which to write the top words for each topic and any Dirichlet parameters. " +
         "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String docTopicsFile = new CommandOption.String
        (PolylingualTopicModel.class, "output-doc-topics", "FILENAME", true, null,
         "The filename in which to write the topic proportions per document, at the end of the iterations. " +
         "By default this is null, indicating that no file will be written.", null);

    static CommandOption.Double docTopicsThreshold = new CommandOption.Double
        (PolylingualTopicModel.class, "doc-topics-threshold", "DECIMAL", true, 0.0,
         "When writing topic proportions per document with --output-doc-topics, " +
         "do not print topics with proportions less than this threshold value.", null);

    static CommandOption.Integer docTopicsMax = new CommandOption.Integer
        (PolylingualTopicModel.class, "doc-topics-max", "INTEGER", true, -1,
         "When writing topic proportions per document with --output-doc-topics, " +
         "do not print more than INTEGER number of topics. " +
         "A negative value indicates that all topics should be printed.", null);

    static CommandOption.Integer outputModelIntervalOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "output-model-interval", "INTEGER", true, 0,
         "The number of iterations between writing the model (and its Gibbs sampling state) to a binary file. " +
         "You must also set the --output-model to use this option, whose argument will be the prefix of the filenames.", null);

    static CommandOption.Integer outputStateIntervalOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "output-state-interval", "INTEGER", true, 0,
         "The number of iterations between writing the sampling state to a text file. " +
         "You must also set the --output-state to use this option, whose argument will be the prefix of the filenames.", null);

    static CommandOption.Integer numTopicsOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "num-topics", "INTEGER", true, 10,
         "The number of topics to fit.", null);

    static CommandOption.Integer numIterationsOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "num-iterations", "INTEGER", true, 1000,
         "The number of iterations of Gibbs sampling.", null);

    static CommandOption.Integer randomSeedOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "random-seed", "INTEGER", true, 0,
         "The random seed for the Gibbs sampler. Default is 0, which will use the clock.", null);

    static CommandOption.Integer topWordsOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "num-top-words", "INTEGER", true, 20,
         "The number of most probable words to print for each topic after model estimation.", null);

    static CommandOption.Integer showTopicsIntervalOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "show-topics-interval", "INTEGER", true, 50,
         "The number of iterations between printing a brief summary of the topics so far.", null);

    static CommandOption.Integer optimizeIntervalOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "optimize-interval", "INTEGER", true, 0,
         "The number of iterations between reestimating Dirichlet hyperparameters.", null);

    static CommandOption.Integer optimizeBurnInOption = new CommandOption.Integer
        (PolylingualTopicModel.class, "optimize-burn-in", "INTEGER", true, 200,
         "The number of iterations to run before first estimating Dirichlet hyperparameters.", null);

    static CommandOption.Double alphaOption = new CommandOption.Double
        (PolylingualTopicModel.class, "alpha", "DECIMAL", true, 50.0,
         "Alpha parameter: smoothing over topic distribution.", null);

    static CommandOption.Double betaOption = new CommandOption.Double
        (PolylingualTopicModel.class, "beta", "DECIMAL", true, 0.01,
         "Beta parameter: smoothing over unigram distribution.", null);

    public class TopicAssignment implements Serializable {
        public Instance[] instances;
        public LabelSequence[] topicSequences;
        public Labeling topicDistribution;

        public TopicAssignment (Instance[] instances, LabelSequence[] topicSequences) {
            this.instances = instances;
            this.topicSequences = topicSequences;
        }
    }

    int numLanguages = 1;

    protected ArrayList<TopicAssignment> data;  // the training instances and their topic assignments
    protected LabelAlphabet topicAlphabet;  // the alphabet for the topics

    protected int numStopwords = 0;

    protected int numTopics; // Number of topics to be fit

    HashSet<String> testingIDs = null;

    // These values are used to encode type/topic counts as
    //  count/topic pairs in a single int.
    protected int topicMask;
    protected int topicBits;

    protected Alphabet[] alphabets;
    protected int[] vocabularySizes;

    protected double[] alpha;   // Dirichlet(alpha,alpha,...) is the distribution over topics
    protected double alphaSum;
    protected double[] betas;   // Prior on per-topic multinomial distribution over words
    protected double[] betaSums;

    protected int[] languageMaxTypeCounts;

    public static final double DEFAULT_BETA = 0.01;

    protected double[] languageSmoothingOnlyMasses;
    protected double[][] languageCachedCoefficients;

    int topicTermCount = 0;
    int betaTopicCount = 0;
    int smoothingOnlyCount = 0;
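    // A worked example of the packed count/topic encoding used by the
    //  type/topic arrays below (illustrative values, not from a real run):
    //  with numTopics = 32 we get topicBits = 5 and topicMask = 31
    //  (binary 11111). A count of 3 for topic 7 is then stored as
    //  (3 << 5) + 7 = 103; unpacking gives topic = 103 & topicMask = 7
    //  and count = 103 >> topicBits = 3. Because the count occupies the
    //  high bits, sorting the packed ints in descending numeric order
    //  sorts type/topic pairs by count.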
    // An array to put the topic counts for the current document.
    // Initialized locally below. Defined here to avoid
    //  garbage collection overhead.
    protected int[] oneDocTopicCounts; // indexed by <topic index>

    protected int[][][] languageTypeTopicCounts; // indexed by <language index, feature index, topic index>
    protected int[][] languageTokensPerTopic; // indexed by <language index, topic index>

    // for Dirichlet estimation
    protected int[] docLengthCounts; // histogram of document sizes, summed over languages
    protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, count>

    protected int iterationsSoFar = 1;
    public int numIterations = 1000;
    public int burninPeriod = 5;
    public int saveSampleInterval = 5; // was 10;
    public int optimizeInterval = 10;
    public int showTopicsInterval = 10; // was 50;
    public int wordsPerTopic = 7;

    protected int saveModelInterval = 0;
    protected String modelFilename;

    protected int saveStateInterval = 0;
    protected String stateFilename = null;

    protected Randoms random;
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = false;

    public PolylingualTopicModel (int numberOfTopics) {
        this (numberOfTopics, numberOfTopics);
    }

    public PolylingualTopicModel (int numberOfTopics, double alphaSum) {
        this (numberOfTopics, alphaSum, new Randoms());
    }

    private static LabelAlphabet newLabelAlphabet (int numTopics) {
        LabelAlphabet ret = new LabelAlphabet();
        for (int i = 0; i < numTopics; i++)
            ret.lookupIndex("topic" + i);
        return ret;
    }

    public PolylingualTopicModel (int numberOfTopics, double alphaSum, Randoms random) {
        this (newLabelAlphabet (numberOfTopics), alphaSum, random);
    }

    public PolylingualTopicModel (LabelAlphabet topicAlphabet, double alphaSum, Randoms random) {
        this.data = new ArrayList<TopicAssignment>();
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();

        if (Integer.bitCount(numTopics) == 1) {
            // exact power of 2
            topicMask = numTopics - 1;
            topicBits = Integer.bitCount(topicMask);
        }
        else {
            // otherwise add an extra bit
            topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
            topicBits = Integer.bitCount(topicMask);
        }

        this.alphaSum = alphaSum;
        this.alpha = new double[numTopics];
        Arrays.fill(alpha, alphaSum / numTopics);
        this.random = random;

        formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(5);

        System.err.println("Polylingual LDA: " + numTopics + " topics, " +
                           topicBits + " topic bits, " +
                           Integer.toBinaryString(topicMask) + " topic mask");
    }

    public void loadTestingIDs(File testingIDFile) throws IOException {
        testingIDs = new HashSet<String>();

        BufferedReader in = new BufferedReader(new FileReader(testingIDFile));
        String id = null;
        while ((id = in.readLine()) != null) {
            testingIDs.add(id);
        }
        in.close();
    }

    public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
    public int getNumTopics() { return numTopics; }
    public ArrayList<TopicAssignment> getData() { return data; }

    public void setNumIterations (int numIterations) {
        this.numIterations = numIterations;
    }

    public void setBurninPeriod (int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    public void setTopicDisplay(int interval, int n) {
        this.showTopicsInterval = interval;
        this.wordsPerTopic = n;
    }

    public void setRandomSeed(int seed) {
        random = new Randoms(seed);
    }

    public void setOptimizeInterval(int interval) {
        this.optimizeInterval = interval;
    }

    public void setModelOutput(int interval, String filename) {
        this.saveModelInterval = interval;
        this.modelFilename = filename;
    }

    /** Define how often and where to save the state
     *
     * @param interval Save a copy of the state every <code>interval</code> iterations.
     * @param filename Save the state to this file, with the iteration number as a suffix
     */
    public void setSaveState(int interval, String filename) {
        this.saveStateInterval = interval;
        this.stateFilename = filename;
    }

    public void addInstances (InstanceList[] training) {

        numLanguages = training.length;

        languageTokensPerTopic = new int[numLanguages][numTopics];

        alphabets = new Alphabet[ numLanguages ];
        vocabularySizes = new int[ numLanguages ];
        betas = new double[ numLanguages ];
        betaSums = new double[ numLanguages ];
        languageMaxTypeCounts = new int[ numLanguages ];
        languageTypeTopicCounts = new int[ numLanguages ][][];

        int numInstances = training[0].size();

        HashSet[] stoplists = new HashSet[ numLanguages ];

        for (int language = 0; language < numLanguages; language++) {

            if (training[language].size() != numInstances) {
                System.err.println("Warning: language " + language + " has " +
                                   training[language].size() + " instances, lang 0 has " +
                                   numInstances);
            }

            alphabets[ language ] = training[ language ].getDataAlphabet();
            vocabularySizes[ language ] = alphabets[ language ].size();

            betas[language] = DEFAULT_BETA;
            betaSums[language] = betas[language] * vocabularySizes[ language ];

            languageTypeTopicCounts[language] = new int[ vocabularySizes[language] ][];

            int[][] typeTopicCounts = languageTypeTopicCounts[language];

            // Get the total number of occurrences of each word type
            int[] typeTotals = new int[ vocabularySizes[language] ];

            for (Instance instance : training[language]) {
                if (testingIDs != null &&
                    testingIDs.contains(instance.getName())) {
                    continue;
                }

                FeatureSequence tokens = (FeatureSequence) instance.getData();
                for (int position = 0; position < tokens.getLength(); position++) {
                    int type = tokens.getIndexAtPosition(position);
                    typeTotals[ type ]++;
                }
            }

            /* Automatic stoplist creation, currently disabled

            TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
            for (int type = 0; type < vocabularySizes[language]; type++) {
                sortedWords.add(new IDSorter(type, typeTotals[type]));
            }

            stoplists[language] = new HashSet<Integer>();

            Iterator<IDSorter> typeIterator = sortedWords.iterator();
            int totalStopwords = 0;

            while (typeIterator.hasNext() && totalStopwords < numStopwords) {
                stoplists[language].add(typeIterator.next().getID());
            }
            */

            // Allocate enough space so that we never have to worry about
            //  overflows: either the number of topics or the number of times
            //  the type occurs.
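            // For example (illustrative values): a type that occurs 3 times in
            //  the corpus can be assigned to at most 3 distinct topics, so its
            //  array needs only min(numTopics, 3) slots, while a type with
            //  10,000 occurrences under 100 topics still needs only 100 slots.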
            for (int type = 0; type < vocabularySizes[language]; type++) {
                if (typeTotals[type] > languageMaxTypeCounts[language]) {
                    languageMaxTypeCounts[language] = typeTotals[type];
                }
                typeTopicCounts[type] = new int[ Math.min(numTopics, typeTotals[type]) ];
            }
        }

        for (int doc = 0; doc < numInstances; doc++) {
            if (testingIDs != null &&
                testingIDs.contains(training[0].get(doc).getName())) {
                continue;
            }

            Instance[] instances = new Instance[ numLanguages ];
            LabelSequence[] topicSequences = new LabelSequence[ numLanguages ];

            for (int language = 0; language < numLanguages; language++) {

                int[][] typeTopicCounts = languageTypeTopicCounts[language];
                int[] tokensPerTopic = languageTokensPerTopic[language];

                instances[language] = training[language].get(doc);
                FeatureSequence tokens = (FeatureSequence) instances[language].getData();
                topicSequences[language] =
                    new LabelSequence(topicAlphabet, new int[ tokens.size() ]);

                int[] topics = topicSequences[language].getFeatures();
                for (int position = 0; position < tokens.size(); position++) {

                    int type = tokens.getIndexAtPosition(position);
                    int[] currentTypeTopicCounts = typeTopicCounts[ type ];

                    int topic = random.nextInt(numTopics);

                    // If the word is one of the [numStopwords] most
                    //  frequent words, put it in a non-sampled topic.
                    //if (stoplists[language].contains(type)) {
                    //  topic = -1;
                    //}

                    topics[position] = topic;
                    tokensPerTopic[topic]++;

                    // The format for these arrays is
                    //  the topic in the rightmost bits
                    //  the count in the remaining (left) bits.
                    // Since the count is in the high bits, sorting (desc)
                    //  by the numeric value of the int guarantees that
                    //  higher counts will be before the lower counts.

                    // Start by assuming that the array is either empty
                    //  or is in sorted (descending) order.

                    // Here we are only adding counts, so if we find
                    //  an existing location with the topic, we only need
                    //  to ensure that it is not larger than its left neighbor.

                    int index = 0;
                    int currentTopic = currentTypeTopicCounts[index] & topicMask;
                    int currentValue;

                    while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
                        index++;

                        /*
                        // Debugging output...
                        if (index >= currentTypeTopicCounts.length) {
                            for (int i=0; i < currentTypeTopicCounts.length; i++) {
                                System.out.println((currentTypeTopicCounts[i] & topicMask) + ":" +
                                                   (currentTypeTopicCounts[i] >> topicBits) + " ");
                            }
                            System.out.println(type + " " + typeTotals[type]);
                        }
                        */

                        currentTopic = currentTypeTopicCounts[index] & topicMask;
                    }
                    currentValue = currentTypeTopicCounts[index] >> topicBits;

                    if (currentValue == 0) {
                        // new value is 1, so we don't have to worry about sorting
                        //  (except by topic suffix, which doesn't matter)
                        currentTypeTopicCounts[index] = (1 << topicBits) + topic;
                    }
                    else {
                        currentTypeTopicCounts[index] =
                            ((currentValue + 1) << topicBits) + topic;

                        // Now ensure that the array is still sorted by
                        //  bubbling this value up.
                        while (index > 0 &&
                               currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                            int temp = currentTypeTopicCounts[index];
                            currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                            currentTypeTopicCounts[index - 1] = temp;

                            index--;
                        }
                    }
                }
            }

            TopicAssignment t = new TopicAssignment (instances, topicSequences);
            data.add (t);
        }

        initializeHistograms();

        languageSmoothingOnlyMasses = new double[ numLanguages ];
        languageCachedCoefficients = new double[ numLanguages ][ numTopics ];

        cacheValues();
    }

    /**
     *  Gather statistics on the size of documents
     *  and create histograms for use in Dirichlet hyperparameter
     *  optimization.
     */
    private void initializeHistograms() {

        int maxTokens = 0;
        int totalTokens = 0;

        for (int doc = 0; doc < data.size(); doc++) {
            int length = 0;
            for (LabelSequence sequence : data.get(doc).topicSequences) {
                length += sequence.getLength();
            }

            if (length > maxTokens) {
                maxTokens = length;
            }

            totalTokens += length;
        }

        System.err.println("max tokens: " + maxTokens);
        System.err.println("total tokens: " + totalTokens);

        docLengthCounts = new int[maxTokens + 1];
        topicDocCounts = new int[numTopics][maxTokens + 1];
    }

    private void cacheValues() {

        for (int language = 0; language < numLanguages; language++) {
            languageSmoothingOnlyMasses[language] = 0.0;

            for (int topic = 0; topic < numTopics; topic++) {
                languageSmoothingOnlyMasses[language] +=
                    alpha[topic] * betas[language] /
                    (languageTokensPerTopic[language][topic] + betaSums[language]);
                languageCachedCoefficients[language][topic] =
                    alpha[topic] /
                    (languageTokensPerTopic[language][topic] + betaSums[language]);
            }
        }
    }

    private void clearHistograms() {
        Arrays.fill(docLengthCounts, 0);
        for (int topic = 0; topic < topicDocCounts.length; topic++)
            Arrays.fill(topicDocCounts[topic], 0);
    }

    public void estimate () throws IOException {
        estimate (numIterations);
    }

    public void estimate (int iterationsThisRound) throws IOException {

        long startTime = System.currentTimeMillis();
        int maxIteration = iterationsSoFar + iterationsThisRound;

        long totalTime = 0;

        for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) {
            long iterationStart = System.currentTimeMillis();

            if (showTopicsInterval != 0 && iterationsSoFar != 0 &&
                iterationsSoFar % showTopicsInterval == 0) {
                System.out.println();
                printTopWords (System.out, wordsPerTopic, false);
            }

            if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
                this.printState(new File(stateFilename + '.' + iterationsSoFar));
            }

            /*
            if (saveModelInterval != 0 && iterations % saveModelInterval == 0) {
                this.write (new File(modelFilename + '.' + iterations));
            }
            */

            // TODO this condition should also check that we have more than one sample to work with here
            //  (The number of samples actually obtained is not yet tracked.)
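            // A sketch of the bookkeeping behind this optimization:
            //  docLengthCounts[n] holds the number of sampled documents with
            //  n tokens (summed over languages), and topicDocCounts[k][n]
            //  holds the number of documents in which exactly n tokens were
            //  assigned to topic k. Dirichlet.learnParameters uses these
            //  histograms for a fixed-point update of the asymmetric alphas.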
            if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
                iterationsSoFar % optimizeInterval == 0) {

                alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);
                optimizeBetas();
                clearHistograms();
                cacheValues();
            }

            // Loop over every document in the corpus
            topicTermCount = betaTopicCount = smoothingOnlyCount = 0;

            for (int doc = 0; doc < data.size(); doc++) {
                sampleTopicsForOneDoc (data.get(doc),
                                       (iterationsSoFar >= burninPeriod &&
                                        iterationsSoFar % saveSampleInterval == 0));
            }

            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            totalTime += elapsedMillis;

            if ((iterationsSoFar + 1) % 10 == 0) {
                double ll = modelLogLikelihood();
                System.out.println(elapsedMillis + "\t" + totalTime + "\t" + ll);
            }
            else {
                System.out.print(elapsedMillis + " ");
            }
        }

        /*
        long seconds = Math.round((System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60;  seconds %= 60;
        long hours = minutes / 60;  minutes %= 60;
        long days = hours / 24;  hours %= 24;
        System.out.print ("\nTotal time: ");
        if (days != 0) { System.out.print(days); System.out.print(" days "); }
        if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
        if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
        System.out.print(seconds); System.out.println(" seconds");
        */
    }

    public void optimizeBetas() {

        for (int language = 0; language < numLanguages; language++) {

            // The histogram starts at count 0, so if all of the
            //  tokens of the most frequent type were assigned to one topic,
            //  we would need to store a maxTypeCount + 1 count.
            int[] countHistogram = new int[languageMaxTypeCounts[language] + 1];

            // Now count the number of type/topic pairs that have
            //  each number of tokens.
            int[][] typeTopicCounts = languageTypeTopicCounts[language];
            int[] tokensPerTopic = languageTokensPerTopic[language];

            int index;
            for (int type = 0; type < vocabularySizes[language]; type++) {
                int[] counts = typeTopicCounts[type];
                index = 0;
                while (index < counts.length && counts[index] > 0) {
                    int count = counts[index] >> topicBits;
                    countHistogram[count]++;
                    index++;
                }
            }

            // Figure out how large we need to make the "observation lengths"
            //  histogram.
            int maxTopicSize = 0;
            for (int topic = 0; topic < numTopics; topic++) {
                if (tokensPerTopic[topic] > maxTopicSize) {
                    maxTopicSize = tokensPerTopic[topic];
                }
            }

            // Now allocate it and populate it.
            int[] topicSizeHistogram = new int[maxTopicSize + 1];
            for (int topic = 0; topic < numTopics; topic++) {
                topicSizeHistogram[ tokensPerTopic[topic] ]++;
            }

            betaSums[language] = Dirichlet.learnSymmetricConcentration(countHistogram,
                                                                       topicSizeHistogram,
                                                                       vocabularySizes[ language ],
                                                                       betaSums[language]);
            betas[language] = betaSums[language] / vocabularySizes[ language ];
        }
    }

    protected void sampleTopicsForOneDoc (TopicAssignment topicAssignment,
                                          boolean shouldSaveState) {

        int[] currentTypeTopicCounts;
        int type, oldTopic, newTopic;

        int[] localTopicCounts = new int[numTopics];
        int[] localTopicIndex = new int[numTopics];

        for (int language = 0; language < numLanguages; language++) {
            int[] oneDocTopics = topicAssignment.topicSequences[language].getFeatures();
            int docLength = topicAssignment.topicSequences[language].getLength();

            // populate topic counts
            for (int position = 0; position < docLength; position++) {
                localTopicCounts[oneDocTopics[position]]++;
            }
        }

        // Build an array that densely lists the topics that
        //  have non-zero counts.
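        // For example (illustrative values): with numTopics = 5 and
        //  localTopicCounts = {0, 3, 0, 1, 2}, the loop below yields
        //  localTopicIndex = {1, 3, 4} and nonZeroTopics = 3, so the
        //  per-token updates touch three topics rather than all five.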
        int denseIndex = 0;
        for (int topic = 0; topic < numTopics; topic++) {
            if (localTopicCounts[topic] != 0) {
                localTopicIndex[denseIndex] = topic;
                denseIndex++;
            }
        }

        // Record the total number of non-zero topics
        int nonZeroTopics = denseIndex;

        for (int language = 0; language < numLanguages; language++) {

            int[] oneDocTopics = topicAssignment.topicSequences[language].getFeatures();
            int docLength = topicAssignment.topicSequences[language].getLength();
            FeatureSequence tokenSequence =
                (FeatureSequence) topicAssignment.instances[language].getData();

            int[][] typeTopicCounts = languageTypeTopicCounts[language];
            int[] tokensPerTopic = languageTokensPerTopic[language];
            double beta = betas[language];
            double betaSum = betaSums[language];

            // Initialize the smoothing-only sampling bucket
            double smoothingOnlyMass = languageSmoothingOnlyMasses[language];

            //for (int topic = 0; topic < numTopics; topic++)
            //  smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);

            // Initialize the cached coefficients, using only smoothing.
            //cachedCoefficients = new double[ numTopics ];
            //for (int topic=0; topic < numTopics; topic++)
            //  cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);

            double[] cachedCoefficients = languageCachedCoefficients[language];

            // Initialize the topic count/beta sampling bucket
            double topicBetaMass = 0.0;

            // Initialize cached coefficients and the topic/beta
            //  normalizing constant.
            for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                int topic = localTopicIndex[denseIndex];
                int n = localTopicCounts[topic];

                // initialize the normalization constant for the (B * n_{t|d}) term
                topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);

                // update the coefficients for the non-zero topics
                cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
            }

            double topicTermMass = 0.0;

            double[] topicTermScores = new double[numTopics];
            int i;
            double score;

            // Iterate over the positions (words) in the document
            for (int position = 0; position < docLength; position++) {

                type = tokenSequence.getIndexAtPosition(position);
                oldTopic = oneDocTopics[position];

                if (oldTopic == -1) { continue; }

                currentTypeTopicCounts = typeTopicCounts[type];

                // Remove this token from all counts.

                // Remove this topic's contribution to the
                //  normalizing constants
                smoothingOnlyMass -= alpha[oldTopic] * beta /
                    (tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass -= beta * localTopicCounts[oldTopic] /
                    (tokensPerTopic[oldTopic] + betaSum);

                // Decrement the local doc/topic counts
                localTopicCounts[oldTopic]--;

                // Maintain the dense index, if we are deleting
                //  the old topic
                if (localTopicCounts[oldTopic] == 0) {

                    // First get to the dense location associated with
                    //  the old topic.
                    denseIndex = 0;

                    // We know it's in there somewhere, so we don't
                    //  need bounds checking.
                    while (localTopicIndex[denseIndex] != oldTopic) {
                        denseIndex++;
                    }

                    // shift all remaining dense indices to the left.
                    while (denseIndex < nonZeroTopics) {
                        if (denseIndex < localTopicIndex.length - 1) {
                            localTopicIndex[denseIndex] =
                                localTopicIndex[denseIndex + 1];
                        }
                        denseIndex++;
                    }

                    nonZeroTopics--;
                }

                // Decrement the global topic count totals
                tokensPerTopic[oldTopic]--;
                //assert(tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
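                // (These running totals implement the three-bucket
                //  decomposition used in SparseLDA (Yao, Mimno & McCallum,
                //  KDD 2009): the total sampling mass is smoothingOnlyMass +
                //  topicBetaMass + topicTermMass, and keeping the first two
                //  updated incrementally is what makes each token update
                //  proportional to the number of non-zero topics rather
                //  than to numTopics.)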
                // Add the old topic's contribution back into the
                //  normalizing constants.
                smoothingOnlyMass += alpha[oldTopic] * beta /
                    (tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass += beta * localTopicCounts[oldTopic] /
                    (tokensPerTopic[oldTopic] + betaSum);

                // Reset the cached coefficient for this topic
                cachedCoefficients[oldTopic] =
                    (alpha[oldTopic] + localTopicCounts[oldTopic]) /
                    (tokensPerTopic[oldTopic] + betaSum);

                // Now go over the type/topic counts, decrementing
                //  where appropriate, and calculating the score
                //  for each topic at the same time.
                int index = 0;
                int currentTopic, currentValue;

                boolean alreadyDecremented = false;

                topicTermMass = 0.0;

                while (index < currentTypeTopicCounts.length &&
                       currentTypeTopicCounts[index] > 0) {

                    currentTopic = currentTypeTopicCounts[index] & topicMask;
                    currentValue = currentTypeTopicCounts[index] >> topicBits;

                    if (! alreadyDecremented && currentTopic == oldTopic) {

                        // We're decrementing and adding up the
                        //  sampling weights at the same time, but
                        //  decrementing may require us to reorder
                        //  the topics, so after we're done here,
                        //  look at this cell in the array again.

                        currentValue--;
                        if (currentValue == 0) {
                            currentTypeTopicCounts[index] = 0;
                        }
                        else {
                            currentTypeTopicCounts[index] =
                                (currentValue << topicBits) + oldTopic;
                        }

                        // Shift the reduced value to the right, if necessary.
                        int subIndex = index;
                        while (subIndex < currentTypeTopicCounts.length - 1 &&
                               currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]) {
                            int temp = currentTypeTopicCounts[subIndex];
                            currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
                            currentTypeTopicCounts[subIndex + 1] = temp;

                            subIndex++;
                        }

                        alreadyDecremented = true;
                    }
                    else {
                        score = cachedCoefficients[currentTopic] * currentValue;
                        topicTermMass += score;
                        topicTermScores[index] = score;

                        index++;
                    }
                }

                double sample = random.nextUniform() *
                    (smoothingOnlyMass + topicBetaMass + topicTermMass);
                double origSample = sample;

                // Make sure it actually gets set
                newTopic = -1;

                if (sample < topicTermMass) {
                    //topicTermCount++;

                    i = -1;
                    while (sample > 0) {
                        i++;
                        sample -= topicTermScores[i];
                    }

                    newTopic = currentTypeTopicCounts[i] & topicMask;
                    currentValue = currentTypeTopicCounts[i] >> topicBits;

                    currentTypeTopicCounts[i] = ((currentValue + 1) << topicBits) + newTopic;

                    // Bubble the new value up, if necessary
                    while (i > 0 &&
                           currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
                        int temp = currentTypeTopicCounts[i];
                        currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
                        currentTypeTopicCounts[i - 1] = temp;

                        i--;
                    }
                }
                else {
                    sample -= topicTermMass;

                    if (sample < topicBetaMass) {
                        //betaTopicCount++;

                        sample /= beta;

                        for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                            int topic = localTopicIndex[denseIndex];

                            sample -= localTopicCounts[topic] /
                                (tokensPerTopic[topic] + betaSum);

                            if (sample <= 0.0) {
                                newTopic = topic;
                                break;
                            }
                        }
                    }
                    else {
                        //smoothingOnlyCount++;

                        sample -= topicBetaMass;
                        sample /= beta;

                        newTopic = 0;
                        sample -= alpha[newTopic] /
                            (tokensPerTopic[newTopic] + betaSum);

                        while (sample > 0.0) {
                            newTopic++;
                            sample -= alpha[newTopic] /
                                (tokensPerTopic[newTopic] + betaSum);
                        }
                    }

                    // Move to the position for the new topic,
                    //  which may be the first empty position if this
                    //  is a new topic for this word.
                    index = 0;
                    while (currentTypeTopicCounts[index] > 0 &&
                           (currentTypeTopicCounts[index] & topicMask) != newTopic) {
                        index++;
                    }

                    // index should now be set to the position of the new topic,
                    //  which may be an empty cell at the end of the list.

                    if (currentTypeTopicCounts[index] == 0) {
                        // inserting a new topic, guaranteed to be in
                        //  order w.r.t. count, if not topic.
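                        // (A count of 1 is the smallest non-zero count, so
                        //  writing it into the first empty cell cannot break
                        //  the descending-by-count order of the array.)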
                        currentTypeTopicCounts[index] = (1 << topicBits) + newTopic;
                    }
                    else {
                        currentValue = currentTypeTopicCounts[index] >> topicBits;
                        currentTypeTopicCounts[index] =
                            ((currentValue + 1) << topicBits) + newTopic;

                        // Bubble the increased value left, if necessary
                        while (index > 0 &&
                               currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                            int temp = currentTypeTopicCounts[index];
                            currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                            currentTypeTopicCounts[index - 1] = temp;

                            index--;
                        }
                    }
                }

                if (newTopic == -1) {
                    System.err.println("PolylingualTopicModel sampling error: " +
                                       origSample + " " + sample + " " +
                                       smoothingOnlyMass + " " +
                                       topicBetaMass + " " + topicTermMass);
                    newTopic = numTopics - 1; // TODO is this appropriate
                    //throw new IllegalStateException ("PolylingualTopicModel: New topic not sampled.");
                }
                //assert(newTopic != -1);

                // Put that new topic into the counts
                oneDocTopics[position] = newTopic;

                smoothingOnlyMass -= alpha[newTopic] * beta /
                    (tokensPerTopic[newTopic] + betaSum);
                topicBetaMass -= beta * localTopicCounts[newTopic] /
                    (tokensPerTopic[newTopic] + betaSum);

                localTopicCounts[newTopic]++;

                // If this is a new topic for this document,
                //  add the topic to the dense index.
                if (localTopicCounts[newTopic] == 1) {

                    // First find the point where we
                    //  should insert the new topic by going to
                    //  the end (which is the only reason we're keeping
                    //  track of the number of non-zero
                    //  topics) and working backwards
                    denseIndex = nonZeroTopics;

                    while (denseIndex > 0 &&
                           localTopicIndex[denseIndex - 1] > newTopic) {

                        localTopicIndex[denseIndex] =
                            localTopicIndex[denseIndex - 1];
                        denseIndex--;
                    }

                    localTopicIndex[denseIndex] = newTopic;
                    nonZeroTopics++;
                }

                tokensPerTopic[newTopic]++;

                // update the coefficients for the non-zero topics
                cachedCoefficients[newTopic] =
                    (alpha[newTopic] + localTopicCounts[newTopic]) /
                    (tokensPerTopic[newTopic] + betaSum);

                smoothingOnlyMass += alpha[newTopic] * beta /
                    (tokensPerTopic[newTopic] + betaSum);
                topicBetaMass += beta * localTopicCounts[newTopic] /
                    (tokensPerTopic[newTopic] + betaSum);
            }

            // Save the smoothing-only mass to the global cache
            languageSmoothingOnlyMasses[language] = smoothingOnlyMass;
        }

        if (shouldSaveState) {
            // Update the document-topic count histogram,
            //  for Dirichlet estimation
            int totalLength = 0;

            for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                int topic = localTopicIndex[denseIndex];

                topicDocCounts[topic][ localTopicCounts[topic] ]++;
                totalLength += localTopicCounts[topic];
            }

            docLengthCounts[ totalLength ]++;
        }
    }

    public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream (file);
        printTopWords(out, numWords, useNewLines);
        out.close();
    }

    public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {

        TreeSet[][] languageTopicSortedWords = new TreeSet[numLanguages][numTopics];

        for (int language = 0; language < numLanguages; language++) {
            TreeSet[] topicSortedWords = languageTopicSortedWords[language];
            int[][] typeTopicCounts = languageTypeTopicCounts[language];

            for (int topic = 0; topic < numTopics; topic++) {
                topicSortedWords[topic] = new TreeSet<IDSorter>();
            }

            for (int type = 0; type < vocabularySizes[language]; type++) {
                int[] topicCounts = typeTopicCounts[type];

                int index = 0;
                while (index < topicCounts.length &&
                       topicCounts[index] > 0) {

                    int topic = topicCounts[index] & topicMask;
                    int count = topicCounts[index] >> topicBits;

                    topicSortedWords[topic].add(new IDSorter(type, count));

                    index++;
                }
            }
        }

        for (int topic = 0; topic < numTopics; topic++) {
            out.println (topic + "\t" + formatter.format(alpha[topic]));

            for (int language = 0; language < numLanguages; language++) {
                out.print(" " + language + "\t" +
                          languageTokensPerTopic[language][topic] + "\t" +
                          betas[language] + "\t");

                TreeSet<IDSorter> sortedWords = languageTopicSortedWords[language][topic];
                Alphabet alphabet = alphabets[language];

                int word = 0;
                Iterator<IDSorter> iterator = sortedWords.iterator();
                while (iterator.hasNext() && word < numWords) {
                    IDSorter info = iterator.next();
                    out.print(alphabet.lookupObject(info.getID()) + " ");
                    word++;
                }

                out.println();
            }
        }
    }

    public void printDocumentTopics (File f) throws IOException {
        printDocumentTopics (new PrintWriter (f, "UTF-8") );
    }

    public void printDocumentTopics (PrintWriter pw) {
        printDocumentTopics (pw, 0.0, -1);
    }

    /**
     *  @param pw          A print writer
     *  @param threshold   Only print topics with proportion greater than this number
     *  @param max         Print no more than this many topics
     */
    public void printDocumentTopics (PrintWriter pw, double threshold, int max) {
        pw.print ("#doc topic proportion ...\n");

        int docLength;
        int[] topicCounts = new int[ numTopics ];

        IDSorter[] sortedTopics = new IDSorter[ numTopics ];
        for (int topic = 0; topic < numTopics; topic++) {
            // Initialize the sorters with dummy values
            sortedTopics[topic] = new IDSorter(topic, topic);
        }

        if (max < 0 || max > numTopics) {
            max = numTopics;
        }

        for (int di = 0; di < data.size(); di++) {
            pw.print (di); pw.print (' ');

            int totalLength = 0;

            for (int language = 0; language < numLanguages; language++) {

                LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();

                docLength = topicSequence.getLength();
                totalLength += docLength;

                // Count up the tokens
                for (int token = 0; token < docLength; token++) {
                    topicCounts[ currentDocTopics[token] ]++;
                }
            }

            // And normalize
            for (int topic = 0; topic < numTopics; topic++) {
                sortedTopics[topic].set(topic, (float) topicCounts[topic] / totalLength);
            }

            Arrays.sort(sortedTopics);

            for (int i = 0; i < max; i++) {
                if (sortedTopics[i].getWeight() < threshold) { break; }

                pw.print (sortedTopics[i].getID() + " " +
                          sortedTopics[i].getWeight() + " ");
            }
            pw.print (" \n");

            Arrays.fill(topicCounts, 0);
        }
    }

    public void printState (File f) throws IOException {
        PrintStream out =
            new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))),
                            false, "UTF-8");
        printState(out);
        out.close();
    }

    public void printState (PrintStream out) {

        out.println ("#doc lang pos typeindex type topic");

        for (int doc = 0; doc < data.size(); doc++) {
            for (int language = 0; language < numLanguages; language++) {
                FeatureSequence tokenSequence =
                    (FeatureSequence) data.get(doc).instances[language].getData();
                LabelSequence topicSequence =
                    (LabelSequence) data.get(doc).topicSequences[language];

                for (int pi = 0; pi < topicSequence.getLength(); pi++) {
                    int type = tokenSequence.getIndexAtPosition(pi);
                    int topic = topicSequence.getIndexAtPosition(pi);
                    out.print(doc); out.print(' ');
                    out.print(language); out.print(' ');
                    out.print(pi); out.print(' ');
                    out.print(type); out.print(' ');
                    out.print(alphabets[language].lookupObject(type)); out.print(' ');
                    out.print(topic); out.println();
                }
            }
        }
    }

    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // The likelihood of the model is a combination of a
        //  Dirichlet-multinomial for the words in each topic
        //  and a Dirichlet-multinomial for the topics in each
        //  document.
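        // (As a concrete instance of the formula in the next comment, with
        //  illustrative values: a document with 3 tokens of topic 0 and
        //  1 token of topic 1 under alpha = {0.5, 0.5} contributes
        //    logGamma(1.0) - logGamma(5.0)
        //      + [logGamma(3.5) - logGamma(0.5)]
        //      + [logGamma(1.5) - logGamma(0.5)]
        //  to the document side of the log likelihood.)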
        // The likelihood function of a Dirichlet-multinomial is
        //   Gamma( sum_i alpha_i )      prod_i Gamma( alpha_i + N_i )
        //  ------------------------ * --------------------------------
        //   prod_i Gamma( alpha_i )     Gamma( sum_i (alpha_i + N_i) )
        //
        // So the log likelihood is
        //  logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) +
        //   sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]

        // Do the documents first

        int[] topicCounts = new int[numTopics];
        double[] topicLogGammas = new double[numTopics];

        for (int topic = 0; topic < numTopics; topic++) {
            topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
        }

        for (int doc = 0; doc < data.size(); doc++) {
            int totalLength = 0;

            for (int language = 0; language < numLanguages; language++) {

                LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();

                totalLength += topicSequence.getLength();

                // Count up the tokens
                for (int token = 0; token < topicSequence.getLength(); token++) {
                    topicCounts[ currentDocTopics[token] ]++;
                }
            }

            for (int topic = 0; topic < numTopics; topic++) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
                                      topicLogGammas[ topic ]);
                }
            }

            // subtract the (count + parameter) sum term
            logLikelihood -= Dirichlet.logGammaStirling(alphaSum + totalLength);

            Arrays.fill(topicCounts, 0);
        }

        // add the parameter sum term
        logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);

        // And the topics

        for (int language = 0; language < numLanguages; language++) {

            int[][] typeTopicCounts = languageTypeTopicCounts[language];
            int[] tokensPerTopic = languageTokensPerTopic[language];
            double beta = betas[language];

            // Count the number of type-topic pairs
            int nonZeroTypeTopics = 0;

            for (int type = 0; type < vocabularySizes[language]; type++) {
                // reuse this array as a pointer
                topicCounts = typeTopicCounts[type];

                int index = 0;
                while (index < topicCounts.length &&
                       topicCounts[index] > 0) {
                    int topic = topicCounts[index] & topicMask;
                    int count = topicCounts[index] >> topicBits;

                    nonZeroTypeTopics++;
                    logLikelihood += Dirichlet.logGammaStirling(beta + count);

                    if (Double.isNaN(logLikelihood)) {
                        System.out.println(count);
                        System.exit(1);
                    }

                    index++;
                }
            }

            for (int topic = 0; topic < numTopics; topic++) {
                // betaSums[language] is beta * vocabularySizes[language],
                //  the concentration of the symmetric Dirichlet over word types
                logLikelihood -=
                    Dirichlet.logGammaStirling( betaSums[language] +
                                                tokensPerTopic[ topic ] );

                if (Double.isNaN(logLikelihood)) {
                    System.out.println("after topic " + topic + " " + tokensPerTopic[ topic ]);
                    System.exit(1);
                }
            }

            // one logGamma(beta * V) normalizer per topic
            logLikelihood +=
                (numTopics * Dirichlet.logGammaStirling(betaSums[language])) -
                (Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);
        }

        if (Double.isNaN(logLikelihood)) {
            System.out.println("at the end");
            System.exit(1);
        }

        return logLikelihood;
    }

    /** Return a tool for estimating topic distributions for new documents */
    public TopicInferencer getInferencer(int language) {
        return new TopicInferencer(languageTypeTopicCounts[language],
                                   languageTokensPerTopic[language],
                                   alphabets[language],
                                   alpha, betas[language], betaSums[language]);
    }

    // Serialization

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    private void writeObject (ObjectOutputStream out) throws IOException {
        out.writeInt (CURRENT_SERIAL_VERSION);

        out.writeInt(numLanguages);
        out.writeObject(data);
        out.writeObject(topicAlphabet);

        out.writeInt(numTopics);

        out.writeObject(testingIDs);

        out.writeInt(topicMask);
        out.writeInt(topicBits);

        out.writeObject(alphabets);
        out.writeObject(vocabularySizes);
        out.writeObject(alpha);
        out.writeDouble(alphaSum);
        out.writeObject(betas);
        out.writeObject(betaSums);

        out.writeObject(languageMaxTypeCounts);

        out.writeObject(languageTypeTopicCounts);
        out.writeObject(languageTokensPerTopic);

        out.writeObject(languageSmoothingOnlyMasses);
        out.writeObject(languageCachedCoefficients);

        out.writeObject(docLengthCounts);
        out.writeObject(topicDocCounts);

        out.writeInt(numIterations);
        out.writeInt(burninPeriod);
        out.writeInt(saveSampleInterval);
        out.writeInt(optimizeInterval);
        out.writeInt(showTopicsInterval);
        out.writeInt(wordsPerTopic);

        out.writeInt(saveStateInterval);
        out.writeObject(stateFilename);

        out.writeInt(saveModelInterval);
        out.writeObject(modelFilename);

        out.writeObject(random);
        out.writeObject(formatter);
        out.writeBoolean(printLogLikelihood);
    }

    @SuppressWarnings("unchecked")
    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {

        int version = in.readInt();

        numLanguages = in.readInt();
        data = (ArrayList<TopicAssignment>) in.readObject ();
        topicAlphabet = (LabelAlphabet) in.readObject();

        numTopics = in.readInt();

        testingIDs = (HashSet<String>) in.readObject();

        topicMask = in.readInt();
        topicBits = in.readInt();

        alphabets = (Alphabet[]) in.readObject();
        vocabularySizes = (int[]) in.readObject();

        alpha = (double[]) in.readObject();
        alphaSum = in.readDouble();
        betas = (double[]) in.readObject();
        betaSums = (double[]) in.readObject();

        languageMaxTypeCounts = (int[]) in.readObject();

        languageTypeTopicCounts = (int[][][]) in.readObject();
        languageTokensPerTopic = (int[][]) in.readObject();

        languageSmoothingOnlyMasses = (double[]) in.readObject();
        languageCachedCoefficients = (double[][]) in.readObject();

        docLengthCounts = (int[]) in.readObject();
        topicDocCounts = (int[][]) in.readObject();

        numIterations = in.readInt();
        burninPeriod = in.readInt();
        saveSampleInterval = in.readInt();
        optimizeInterval = in.readInt();
        showTopicsInterval = in.readInt();
        wordsPerTopic = in.readInt();

        saveStateInterval = in.readInt();
        stateFilename = (String) in.readObject();

        saveModelInterval = in.readInt();
        modelFilename = (String) in.readObject();

        random = (Randoms) in.readObject();
        formatter = (NumberFormat) in.readObject();
        printLogLikelihood = in.readBoolean();
    }

    public void write (File serializedModelFile) {
        try {
            ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(serializedModelFile));
            oos.writeObject(this);
            oos.close();
        } catch (IOException e) {
            System.err.println("Problem serializing PolylingualTopicModel to file " +
                               serializedModelFile + ": " + e);
        }
    }

    public static PolylingualTopicModel read (File f) throws Exception {

        PolylingualTopicModel topicModel = null;

        ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
        topicModel = (PolylingualTopicModel) ois.readObject();
        ois.close();

        topicModel.initializeHistograms();

        return topicModel;
    }

    public static void main (String[] args) throws IOException {

        CommandOption.setSummary (PolylingualTopicModel.class,
                                  "A tool for estimating, saving and printing diagnostics for topic models over comparable corpora.");
        CommandOption.process (PolylingualTopicModel.class, args);

        PolylingualTopicModel topicModel = null;

        if (inputModelFilename.value != null) {
            try {
                topicModel = PolylingualTopicModel.read(new File(inputModelFilename.value));
            } catch (Exception e) {
                System.err.println("Unable to restore saved topic model " +
                                   inputModelFilename.value + ": " + e);
                System.exit(1);
            }
        }
        else {
            int numLanguages = languageInputFiles.value.length;

            InstanceList[] training = new InstanceList[ numLanguages ];
            for (int i = 0; i < training.length; i++) {
                training[i] = InstanceList.load(new File(languageInputFiles.value[i]));
                if (training[i] != null) {
                    System.out.println("Language " + i + " loaded.");
                }
                else {
                    System.out.println("Language " + i + " failed to load (null).");
                }
            }

            System.out.println ("Data loaded.");

            // For historical reasons we currently only support FeatureSequence data,
            //  not the FeatureVector, which is the default for the input functions.
            //  Provide a warning to avoid ClassCastExceptions.
            if (training[0].size() > 0 &&
                training[0].get(0) != null) {
                Object data = training[0].get(0).getData();
                if (! (data instanceof FeatureSequence)) {
                    System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
                    System.exit(1);
                }
            }

            topicModel = new PolylingualTopicModel (numTopicsOption.value, alphaOption.value);
            if (randomSeedOption.value != 0) {
                topicModel.setRandomSeed(randomSeedOption.value);
            }

            topicModel.addInstances(training);
        }

        topicModel.setTopicDisplay(showTopicsIntervalOption.value, topWordsOption.value);

        topicModel.setNumIterations(numIterationsOption.value);
        topicModel.setOptimizeInterval(optimizeIntervalOption.value);
        topicModel.setBurninPeriod(optimizeBurnInOption.value);

        if (outputStateIntervalOption.value != 0) {
            topicModel.setSaveState(outputStateIntervalOption.value, stateFile.value);
        }

        if (outputModelIntervalOption.value != 0) {
            topicModel.setModelOutput(outputModelIntervalOption.value, outputModelFilename.value);
        }

        topicModel.estimate();

        if (topicKeysFile.value != null) {
            topicModel.printTopWords(new File(topicKeysFile.value), topWordsOption.value, false);
        }

        if (stateFile.value != null) {
            topicModel.printState (new File(stateFile.value));
        }

        if (docTopicsFile.value != null) {
            PrintWriter out = new PrintWriter (new FileWriter (new File(docTopicsFile.value)));
            topicModel.printDocumentTopics(out, docTopicsThreshold.value, docTopicsMax.value);
            out.close();
        }

        if (inferencerFilename.value != null) {
            try {
                for (int language = 0; language < topicModel.numLanguages; language++) {
                    ObjectOutputStream oos =
                        new ObjectOutputStream(new FileOutputStream(inferencerFilename.value + "." + language));
                    oos.writeObject(topicModel.getInferencer(language));
                    oos.close();
                }
            } catch (Exception e) {
                System.err.println(e.getMessage());
            }
        }

        if (outputModelFilename.value != null) {
            assert (topicModel != null);
            topicModel.write(new File(outputModelFilename.value));
        }
    }
}