package cc.mallet.topics; import cc.mallet.types.Alphabet; import cc.mallet.types.IDSorter; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.util.Randoms; import java.io.*; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; /** * Created with IntelliJ IDEA. * User: oyiptong * Date: 2012-08-15 * Time: 9:37 AM */ public class PancakeTopicInferencer extends TopicInferencer { Alphabet alphabet; double smoothingOnlyMass = 0.0; double[] cachedCoefficients; public PancakeTopicInferencer(int[][] typeTopicCounts, int[] tokensPerTopic, Alphabet alphabet, double[] alpha, double beta, double betaSum) { super(typeTopicCounts, tokensPerTopic, alphabet, alpha, beta, betaSum); } public List<List> inferSortedDistributions(InstanceList instances, int numIterations, int thinning, int burnIn, double threshold, int max){ IDSorter[] sortedTopics = new IDSorter[ numTopics ]; for (int topic = 0; topic < numTopics; topic++) { // Initialize the sorters with dummy values sortedTopics[topic] = new IDSorter(topic, topic); } if (max < 0 || max > numTopics) { max = numTopics; } ArrayList<List> output = new ArrayList<List>(); for (Instance instance: instances) { ArrayList document = new ArrayList(); ArrayList<List> topTopics = new ArrayList<List>(); double[] topicDistribution = getSampledDistribution(instance, numIterations, thinning, burnIn); for (int topic = 0; topic < numTopics; topic++) { sortedTopics[topic].set(topic, topicDistribution[topic]); } Arrays.sort(sortedTopics); for (int i = 0; i < max; i++) { if (sortedTopics[i].getWeight() < threshold) { break; } ArrayList topicDist = new ArrayList(); topicDist.add(sortedTopics[i].getID()); topicDist.add(sortedTopics[i].getWeight()); topTopics.add(topicDist); } document.add(instance.getName()); document.add(topTopics); output.add(document); } return output; } public List<List> inferDistributions(InstanceList instances, int numIterations, int thinning, int burnIn, double threshold){ ArrayList<List> output = new ArrayList<List>(); for (Instance instance: instances) { ArrayList document = new ArrayList(); double[] topicDistribution = getSampledDistribution(instance, numIterations, thinning, burnIn); document.add(instance.getName()); document.add(topicDistribution); output.add(document); } return output; } // Serialization private static final long serialVersionUID = 1; private static final int CURRENT_SERIAL_VERSION = 0; private static final int NULL_INTEGER = -1; private void writeObject (ObjectOutputStream out) throws IOException { out.writeInt (CURRENT_SERIAL_VERSION); out.writeObject(alphabet); out.writeInt(numTopics); out.writeInt(topicMask); out.writeInt(topicBits); out.writeInt(numTypes); out.writeObject(alpha); out.writeDouble(beta); out.writeDouble(betaSum); out.writeObject(typeTopicCounts); out.writeObject(tokensPerTopic); out.writeObject(random); out.writeDouble(smoothingOnlyMass); out.writeObject(cachedCoefficients); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { int version = in.readInt (); alphabet = (Alphabet) in.readObject(); numTopics = in.readInt(); topicMask = in.readInt(); topicBits = in.readInt(); numTypes = in.readInt(); alpha = (double[]) in.readObject(); beta = in.readDouble(); betaSum = in.readDouble(); typeTopicCounts = (int[][]) in.readObject(); tokensPerTopic = (int[]) in.readObject(); random = (Randoms) in.readObject(); smoothingOnlyMass = in.readDouble(); cachedCoefficients = (double[]) in.readObject(); } }