package cc.mallet.topics; import cc.mallet.optimize.LimitedMemoryBFGS; import cc.mallet.optimize.OptimizationException; import cc.mallet.types.*; import cc.mallet.classify.MaxEnt; import cc.mallet.pipe.Pipe; import cc.mallet.pipe.Noop; import gnu.trove.TIntIntHashMap; import java.io.IOException; import java.io.PrintStream; import java.io.File; public class DMRTopicModel extends LDAHyper { MaxEnt dmrParameters = null; int numFeatures; int defaultFeatureIndex; Pipe parameterPipe = null; double[][] alphaCache; double[] alphaSumCache; public DMRTopicModel(int numberOfTopics) { super(numberOfTopics); } public void estimate (int iterationsThisRound) throws IOException { numFeatures = data.get(0).instance.getTargetAlphabet().size() + 1; defaultFeatureIndex = numFeatures - 1; int numDocs = data.size(); // TODO consider beginning by sub-sampling? alphaCache = new double[numDocs][numTopics]; alphaSumCache = new double[numDocs]; long startTime = System.currentTimeMillis(); int maxIteration = iterationsSoFar + iterationsThisRound; for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) { long iterationStart = System.currentTimeMillis(); if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) { System.out.println(); printTopWords (System.out, wordsPerTopic, false); } if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) { this.printState(new File(stateFilename + '.' + iterationsSoFar + ".gz")); } if (iterationsSoFar > burninPeriod && optimizeInterval != 0 && iterationsSoFar % optimizeInterval == 0) { // Train regression parameters learnParameters(); } // Loop over every document in the corpus for (int doc = 0; doc < numDocs; doc++) { FeatureSequence tokenSequence = (FeatureSequence) data.get(doc).instance.getData(); LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence; if (dmrParameters != null) { // set appropriate Alpha parameters setAlphas(data.get(doc).instance); } sampleTopicsForOneDoc (tokenSequence, topicSequence, false, false); } long ms = System.currentTimeMillis() - iterationStart; if (ms > 1000) { System.out.print(Math.round(ms / 1000) + "s "); } else { System.out.print(ms + "ms "); } if (iterationsSoFar % 10 == 0) { System.out.println ("<" + iterationsSoFar + "> "); if (printLogLikelihood) System.out.println (modelLogLikelihood()); } System.out.flush(); } long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0); long minutes = seconds / 60; seconds %= 60; long hours = minutes / 60; minutes %= 60; long days = hours / 24; hours %= 24; System.out.print ("\nTotal time: "); if (days != 0) { System.out.print(days); System.out.print(" days "); } if (hours != 0) { System.out.print(hours); System.out.print(" hours "); } if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); } System.out.print(seconds); System.out.println(" seconds"); } /** * Use only the default features to set the topic prior (use no document features) */ public void setAlphas() { double[] parameters = dmrParameters.getParameters(); alphaSum = 0.0; smoothingOnlyMass = 0.0; // Use only the default features to set the topic prior (use no document features) for (int topic=0; topic < numTopics; topic++) { alpha[topic] = Math.exp( parameters[ (topic * numFeatures) + defaultFeatureIndex ] ); alphaSum += alpha[topic]; smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum); cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum); } } /** This method sets the alphas for a hypothetical "document" that contains * a single non-default feature. */ public void setAlphas(int featureIndex) { double[] parameters = dmrParameters.getParameters(); alphaSum = 0.0; smoothingOnlyMass = 0.0; // Use only the default features to set the topic prior (use no document features) for (int topic=0; topic < numTopics; topic++) { alpha[topic] = Math.exp(parameters[ (topic * numFeatures) + featureIndex ] + parameters[ (topic * numFeatures) + defaultFeatureIndex ] ); alphaSum += alpha[topic]; smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum); cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum); } } /** * Set alpha based on features in an instance */ public void setAlphas(Instance instance) { // we can't use the standard score functions from MaxEnt, // since our features are currently in the Target. FeatureVector features = (FeatureVector) instance.getTarget(); if (features == null) { setAlphas(); return; } double[] parameters = dmrParameters.getParameters(); alphaSum = 0.0; smoothingOnlyMass = 0.0; for (int topic = 0; topic < numTopics; topic++) { alpha[topic] = parameters[topic*numFeatures + defaultFeatureIndex] + MatrixOps.rowDotProduct (parameters, numFeatures, topic, features, defaultFeatureIndex, null); alpha[topic] = Math.exp(alpha[topic]); alphaSum += alpha[topic]; smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum); cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum); } } public void learnParameters() { // Create a "fake" pipe with the features in the data and // a trove int-int hashmap of topic counts in the target. if (parameterPipe == null) { parameterPipe = new Noop(); parameterPipe.setDataAlphabet(data.get(0).instance.getTargetAlphabet()); parameterPipe.setTargetAlphabet(topicAlphabet); } InstanceList parameterInstances = new InstanceList(parameterPipe); if (dmrParameters == null) { dmrParameters = new MaxEnt(parameterPipe, new double[numFeatures * numTopics]); } for (int doc=0; doc < data.size(); doc++) { if (data.get(doc).instance.getTarget() == null) { continue; } FeatureCounter counter = new FeatureCounter(topicAlphabet); for (int topic : data.get(doc).topicSequence.getFeatures()) { counter.increment(topic); } // Put the real target in the data field, and the // topic counts in the target field parameterInstances.add( new Instance(data.get(doc).instance.getTarget(), counter.toFeatureVector(), null, null) ); } DMROptimizable optimizable = new DMROptimizable(parameterInstances, dmrParameters); optimizable.setRegularGaussianPriorVariance(0.5); optimizable.setInterceptGaussianPriorVariance(100.0); LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(optimizable); // Optimize once try { optimizer.optimize(); } catch (OptimizationException e) { // step size too small } // Restart with a fresh initialization to improve likelihood try { optimizer.optimize(); } catch (OptimizationException e) { // step size too small } dmrParameters = optimizable.getClassifier(); for (int doc=0; doc < data.size(); doc++) { Instance instance = data.get(doc).instance; FeatureSequence tokens = (FeatureSequence) instance.getData(); if (instance.getTarget() == null) { continue; } int numTokens = tokens.getLength(); // This sets alpha[] and alphaSum setAlphas(instance); // Now cache alpha values for (int topic=0; topic < numTopics; topic++) { alphaCache[doc][topic] = alpha[topic]; } alphaSumCache[doc] = alphaSum; } } public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) { if (dmrParameters != null) { setAlphas(); } super.printTopWords(out, numWords, usingNewLines); } public void writeParameters(File parameterFile) throws IOException { if (dmrParameters != null) { PrintStream out = new PrintStream(parameterFile); dmrParameters.print(out); out.close(); } } private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private static final int NULL_INTEGER = -1; public static void main (String[] args) throws IOException { InstanceList training = InstanceList.load (new File(args[0])); int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200; InstanceList testing = args.length > 2 ? InstanceList.load (new File(args[2])) : null; DMRTopicModel lda = new DMRTopicModel (numTopics); lda.setOptimizeInterval(100); lda.setTopicDisplay(100, 10); lda.addInstances(training); lda.estimate(); lda.writeParameters(new File("dmr.parameters")); lda.printState(new File("dmr.state.gz")); } }