package cc.mallet.topics;

import cc.mallet.types.*;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.FileHandler;
import java.util.logging.Handler;
import java.util.logging.SimpleFormatter;

/**
 * A {@link ParallelTopicModel} that can be serialized to and restored from
 * a byte array (see {@link #read(byte[])}) and that exposes normalized
 * per-document topic weights.
 *
 * Author: oyiptong
 * Date: 2012-08-10
 */
public class PersistentParallelTopicModel extends ParallelTopicModel {

    public PersistentParallelTopicModel(int numberOfTopics) {
        this(numberOfTopics, numberOfTopics, DEFAULT_BETA);
    }

    public PersistentParallelTopicModel(int numberOfTopics, double alphaSum, double beta) {
        super(newLabelAlphabet(numberOfTopics), alphaSum, beta);

        try {
            // Log training output to rotating files: up to ten files of 10 MB each.
            Handler handler = new FileHandler("training.log", 10485760, 10);
            handler.setFormatter(new SimpleFormatter());
            logger.setUseParentHandlers(false);
            logger.addHandler(handler);
        } catch (IOException e) {
            // Fall back to the default stdout logging.
        }
    }

    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet ret = new LabelAlphabet();
        for (int i = 0; i < numTopics; i++) {
            ret.lookupIndex("topic" + i);
        }
        return ret;
    }

    /**
     * Gather statistics on the size of documents
     * and create histograms for use in Dirichlet hyperparameter
     * optimization.
     */
    private void initializeHistograms() {

        int maxTokens = 0;
        totalTokens = 0;
        int seqLen;

        for (int doc = 0; doc < data.size(); doc++) {
            FeatureSequence fs = (FeatureSequence) data.get(doc).instance.getData();
            seqLen = fs.getLength();
            if (seqLen > maxTokens) {
                maxTokens = seqLen;
            }
            totalTokens += seqLen;
        }

        logger.info("max tokens: " + maxTokens);
        logger.info("total tokens: " + totalTokens);

        docLengthCounts = new int[maxTokens + 1];
        topicDocCounts = new int[numTopics][maxTokens + 1];
    }

    // Serialization

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);

        out.writeObject(data);
        out.writeObject(alphabet);
        out.writeObject(topicAlphabet);

        out.writeInt(numTopics);

        out.writeInt(topicMask);
        out.writeInt(topicBits);

        out.writeInt(numTypes);

        out.writeObject(alpha);
        out.writeDouble(alphaSum);
        out.writeDouble(beta);
        out.writeDouble(betaSum);

        out.writeObject(typeTopicCounts);
        out.writeObject(tokensPerTopic);

        out.writeObject(docLengthCounts);
        out.writeObject(topicDocCounts);

        out.writeInt(numIterations);
        out.writeInt(burninPeriod);
        out.writeInt(saveSampleInterval);
        out.writeInt(optimizeInterval);
        out.writeInt(showTopicsInterval);
        out.writeInt(wordsPerTopic);

        out.writeInt(saveStateInterval);
        out.writeObject(stateFilename);

        out.writeInt(saveModelInterval);
        out.writeObject(modelFilename);

        out.writeInt(randomSeed);
        out.writeObject(formatter);
        out.writeBoolean(printLogLikelihood);

        out.writeInt(numThreads);
    }

    @SuppressWarnings("unchecked")
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Consume (and currently ignore) the serial version marker.
        int version = in.readInt();

        data = (ArrayList<TopicAssignment>) in.readObject();
        alphabet = (Alphabet) in.readObject();
        topicAlphabet = (LabelAlphabet) in.readObject();

        numTopics = in.readInt();

        topicMask = in.readInt();
        topicBits = in.readInt();

        numTypes = in.readInt();

        alpha = (double[]) in.readObject();
        alphaSum = in.readDouble();
        beta = in.readDouble();
        betaSum = in.readDouble();

        typeTopicCounts = (int[][]) in.readObject();
        tokensPerTopic = (int[]) in.readObject();

        docLengthCounts = (int[]) in.readObject();
        topicDocCounts = (int[][]) in.readObject();

        numIterations = in.readInt();
        burninPeriod = in.readInt();
        saveSampleInterval = in.readInt();
        optimizeInterval = in.readInt();
        showTopicsInterval = in.readInt();
        wordsPerTopic = in.readInt();

        saveStateInterval = in.readInt();
        stateFilename = (String) in.readObject();

        saveModelInterval = in.readInt();
        modelFilename = (String) in.readObject();

        randomSeed = in.readInt();
        formatter = (NumberFormat) in.readObject();
        printLogLikelihood = in.readBoolean();

        numThreads = in.readInt();
    }
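    /**
     * Returns a {@code data.size() x numTopics} matrix of smoothed topic
     * proportions, one row per training document. Each entry is
     * (alpha[t] + n_{t,d}) / (|d| + alphaSum), where n_{t,d} is the number
     * of tokens in document d currently assigned to topic t, renormalized
     * so that every row sums to 1.
     */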
    public double[][] getNormalizedDocumentTopicWeights() {

        int docLen;
        int[] topicCounts = new int[numTopics];
        double[][] documentWeights = new double[data.size()][numTopics];

        for (int docIndex = 0; docIndex < data.size(); docIndex++) {
            TopicAssignment doc = data.get(docIndex);
            LabelSequence topicSequence = (LabelSequence) doc.topicSequence;

            int[] currentDocTopics = topicSequence.getFeatures();
            docLen = currentDocTopics.length;

            // Reset the per-document counts so totals do not leak across documents.
            Arrays.fill(topicCounts, 0);

            // Count up the tokens
            for (int token = 0; token < docLen; token++) {
                topicCounts[currentDocTopics[token]]++;
            }

            // And normalize
            double weightSum = 0;
            for (int topicId = 0; topicId < numTopics; topicId++) {
                weightSum += (alpha[topicId] + topicCounts[topicId]) / (docLen + alphaSum);
            }

            // Save the proportional topic weight
            for (int topicId = 0; topicId < numTopics; topicId++) {
                documentWeights[docIndex][topicId] =
                        ((alpha[topicId] + topicCounts[topicId]) / (docLen + alphaSum)) / weightSum;
            }
        }
        return documentWeights;
    }

    public static PersistentParallelTopicModel read(byte[] data) throws Exception {

        ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(data));
        PersistentParallelTopicModel topicModel = (PersistentParallelTopicModel) ois.readObject();
        ois.close();

        // Rebuild the document-length histograms, which are derived from the
        // deserialized token data rather than stored directly.
        topicModel.initializeHistograms();
        return topicModel;
    }

    public PancakeTopicInferencer getInferencer() {
        return new PancakeTopicInferencer(typeTopicCounts, tokensPerTopic,
                data.get(0).instance.getDataAlphabet(),
                alpha, beta, betaSum);
    }
}
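/*
 * Usage sketch (not part of the class above; a minimal example assuming the
 * standard MALLET ParallelTopicModel API, with `instances` standing in for a
 * previously built cc.mallet.types.InstanceList):
 *
 *   PersistentParallelTopicModel model =
 *       new PersistentParallelTopicModel(50, 50.0, ParallelTopicModel.DEFAULT_BETA);
 *   model.addInstances(instances);
 *   model.setNumThreads(4);
 *   model.estimate();  // declares throws IOException
 *
 *   // Round-trip through a byte array with standard Java serialization.
 *   java.io.ByteArrayOutputStream bytes = new java.io.ByteArrayOutputStream();
 *   java.io.ObjectOutputStream oos = new java.io.ObjectOutputStream(bytes);
 *   oos.writeObject(model);
 *   oos.close();
 *
 *   PersistentParallelTopicModel restored =
 *       PersistentParallelTopicModel.read(bytes.toByteArray());
 *   double[][] weights = restored.getNormalizedDocumentTopicWeights();
 */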