package cc.mallet.topics.tui; import cc.mallet.util.CommandOption; import cc.mallet.util.Randoms; import cc.mallet.types.InstanceList; import cc.mallet.topics.HierarchicalLDA; import java.io.*; public class HierarchicalLDATUI { static CommandOption.String inputFile = new CommandOption.String (HierarchicalLDATUI.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances. Use - for stdin. " + "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null); static CommandOption.String testingFile = new CommandOption.String (HierarchicalLDATUI.class, "testing", "FILENAME", true, null, "The filename from which to read the list of instances for held-out likelihood calculation. Use - for stdin. " + "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null); static CommandOption.String stateFile = new CommandOption.String (HierarchicalLDATUI.class, "output-state", "FILENAME", true, null, "The filename in which to write the Gibbs sampling state after at the end of the iterations. " + "By default this is null, indicating that no file will be written.", null); static CommandOption.Integer randomSeed = new CommandOption.Integer (HierarchicalLDATUI.class, "random-seed", "INTEGER", true, 0, "The random seed for the Gibbs sampler. Default is 0, which will use the clock.", null); static CommandOption.Integer numIterations = new CommandOption.Integer (Vectors2Topics.class, "num-iterations", "INTEGER", true, 1000, "The number of iterations of Gibbs sampling.", null); static CommandOption.Boolean showProgress = new CommandOption.Boolean (HierarchicalLDATUI.class, "show-progress", "BOOLEAN", false, true, "If true, print a character to standard output after every sampling iteration.", null); static CommandOption.Integer showTopicsInterval = new CommandOption.Integer (HierarchicalLDATUI.class, "show-topics-interval", "INTEGER", true, 50, "The number of iterations between printing a brief summary of the topics so far.", null); static CommandOption.Integer topWords = new CommandOption.Integer (HierarchicalLDATUI.class, "num-top-words", "INTEGER", true, 20, "The number of most probable words to print for each topic after model estimation.", null); static CommandOption.Integer numLevels = new CommandOption.Integer (HierarchicalLDATUI.class, "num-levels", "INTEGER", true, 3, "The number of levels in the tree.", null); static CommandOption.Double alpha = new CommandOption.Double (HierarchicalLDATUI.class, "alpha", "DECIMAL", true, 10.0, "Alpha parameter: smoothing over level distributions.", null); static CommandOption.Double gamma = new CommandOption.Double (HierarchicalLDATUI.class, "gamma", "DECIMAL", true, 1.0, "Gamma parameter: CRP smoothing parameter; number of imaginary customers at next, as yet unused table", null); static CommandOption.Double eta = new CommandOption.Double (HierarchicalLDATUI.class, "eta", "DECIMAL", true, 0.1, "Eta parameter: smoothing over topic-word distributions", null); public static void main (String[] args) throws java.io.IOException { // Process the command-line options CommandOption.setSummary (HierarchicalLDATUI.class, "Hierarchical LDA with a fixed tree depth."); CommandOption.process (HierarchicalLDATUI.class, args); // Load instance lists if (inputFile.value() == null) { System.err.println("Input instance list is required, use --input option"); System.exit(1); } InstanceList instances = InstanceList.load(new File(inputFile.value())); InstanceList testing = null; if (testingFile.value() != null) { testing = InstanceList.load(new File(testingFile.value())); } HierarchicalLDA hlda = new HierarchicalLDA(); // Set hyperparameters hlda.setAlpha(alpha.value()); hlda.setGamma(gamma.value()); hlda.setEta(eta.value()); // Display preferences hlda.setTopicDisplay(showTopicsInterval.value(), topWords.value()); hlda.setProgressDisplay(showProgress.value()); // Initialize random number generator Randoms random = null; if (randomSeed.value() == 0) { random = new Randoms(); } else { random = new Randoms(randomSeed.value()); } // Initialize and start the sampler hlda.initialize(instances, testing, numLevels.value(), random); hlda.estimate(numIterations.value()); // Output results if (stateFile.value() != null) { hlda.printState(new PrintWriter(stateFile.value())); } if (testing != null) { double empiricalLikelihood = hlda.empiricalLikelihood(1000, testing); System.out.println("Empirical likelihood: " + empiricalLikelihood); } } }