/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.util.Arrays;
import java.util.ArrayList;
import cc.mallet.types.*;
import cc.mallet.util.Randoms;
/**
 *  A parallel topic model runnable task: samples topic assignments for
 *  one contiguous slice of the corpus, typically run concurrently with
 *  other workers by ParallelTopicModel.
*
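 *  <p>A minimal single-threaded usage sketch (a hypothetical outline, not
 *  prescribed by this class; the instance list <code>data</code>, the
 *  hyperparameters, and the vocabulary size <code>numTypes</code> are
 *  assumed to come from the caller):
 *  <pre>{@code
 *  int[][] typeTopicCounts = new int[numTypes][numTopics];
 *  int[] tokensPerTopic = new int[numTopics];
 *  WorkerRunnable worker =
 *      new WorkerRunnable(numTopics, alpha, alphaSum, beta,
 *                         new Randoms(), data, typeTopicCounts,
 *                         tokensPerTopic, 0, data.size());
 *  worker.makeOnlyThread(); // no local counts needed with one thread
 *  worker.run();            // one sampling sweep over all documents
 *  }</pre>
 *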
* @author David Mimno, Andrew McCallum
*/
public class WorkerRunnable implements Runnable {
	boolean isFinished = true; // true while idle; run() refuses to start a second concurrent sweep
ArrayList<TopicAssignment> data;
int startDoc, numDocs;
protected int numTopics; // Number of topics to be fit
// These values are used to encode type/topic counts as
// count/topic pairs in a single int.
protected int topicMask;
protected int topicBits;
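	// For example, with 100 topics topicBits is 7 and topicMask is
	// 127 (binary 1111111): the pair (count 3, topic 5) is packed as
	// (3 << 7) + 5 = 389, and unpacked with "389 & topicMask" -> 5
	// and "389 >> topicBits" -> 3.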
protected int numTypes;
	protected double[] alpha;	 // Dirichlet(alpha[0], alpha[1], ...) is the (possibly asymmetric) prior over topics
protected double alphaSum;
protected double beta; // Prior on per-topic multinomial distribution over words
protected double betaSum;
public static final double DEFAULT_BETA = 0.01;
	protected double smoothingOnlyMass = 0.0; // sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum)
	protected double[] cachedCoefficients;    // per-topic (alpha[t] + n_{t|d}) / (tokensPerTopic[t] + betaSum)
protected int[][] typeTopicCounts; // indexed by <feature index, topic index>
protected int[] tokensPerTopic; // indexed by <topic index>
// for dirichlet estimation
protected int[] docLengthCounts; // histogram of document sizes
	protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, per-document count>
boolean shouldSaveState = false;
boolean shouldBuildLocalCounts = true;
protected Randoms random;
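	/**
	 *  @param startDoc the first document in this worker's slice of the corpus
	 *  @param numDocs the number of documents in the slice; the worker
	 *   samples documents [startDoc, min(startDoc + numDocs, data.size())).
	 */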
public WorkerRunnable (int numTopics,
double[] alpha, double alphaSum,
double beta, Randoms random,
ArrayList<TopicAssignment> data,
int[][] typeTopicCounts,
int[] tokensPerTopic,
int startDoc, int numDocs) {
this.data = data;
this.numTopics = numTopics;
this.numTypes = typeTopicCounts.length;
if (Integer.bitCount(numTopics) == 1) {
// exact power of 2
topicMask = numTopics - 1;
topicBits = Integer.bitCount(topicMask);
}
else {
// otherwise add an extra bit
topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
topicBits = Integer.bitCount(topicMask);
}
this.typeTopicCounts = typeTopicCounts;
this.tokensPerTopic = tokensPerTopic;
this.alphaSum = alphaSum;
this.alpha = alpha;
this.beta = beta;
this.betaSum = beta * numTypes;
this.random = random;
this.startDoc = startDoc;
this.numDocs = numDocs;
cachedCoefficients = new double[ numTopics ];
//System.err.println("WorkerRunnable Thread: " + numTopics + " topics, " + topicBits + " topic bits, " +
// Integer.toBinaryString(topicMask) + " topic mask");
}
	/**
	 *  If there is only one thread, we don't need to pay the overhead
	 *  of communicating counts between threads. This method tells the
	 *  worker not to rebuild its local type-topic counts; call it when
	 *  using this code in a non-threaded environment.
	 */
public void makeOnlyThread() {
shouldBuildLocalCounts = false;
}
public int[] getTokensPerTopic() { return tokensPerTopic; }
public int[][] getTypeTopicCounts() { return typeTopicCounts; }
public int[] getDocLengthCounts() { return docLengthCounts; }
public int[][] getTopicDocCounts() { return topicDocCounts; }
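	/**
	 *  Allocate the histograms used for Dirichlet hyperparameter
	 *  optimization. <code>size</code> must be greater than the length
	 *  of the longest document, since document lengths index directly
	 *  into <code>docLengthCounts</code>.
	 */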
public void initializeAlphaStatistics(int size) {
docLengthCounts = new int[size];
topicDocCounts = new int[numTopics][size];
}
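	/**
	 *  Ask this worker to record the document length and document/topic
	 *  count histograms on its next sweep; run() clears the flag after
	 *  one iteration.
	 */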
public void collectAlphaStatistics() {
shouldSaveState = true;
}
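	/**
	 *  Update the beta hyperparameter and its precomputed sum over the
	 *  vocabulary (typically beta * numTypes).
	 */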
public void resetBeta(double beta, double betaSum) {
this.beta = beta;
this.betaSum = betaSum;
}
	/**
	 *  Once we have sampled the local counts, discard the
	 *  "global" type-topic counts and reuse the space to
	 *  build a summary of the type-topic counts specific to
	 *  this worker's section of the corpus.
	 */
public void buildLocalTypeTopicCounts () {
// Clear the topic totals
Arrays.fill(tokensPerTopic, 0);
// Clear the type/topic counts, only
// looking at the entries before the first 0 entry.
for (int type = 0; type < typeTopicCounts.length; type++) {
int[] topicCounts = typeTopicCounts[type];
int position = 0;
while (position < topicCounts.length &&
topicCounts[position] > 0) {
topicCounts[position] = 0;
position++;
}
}
for (int doc = startDoc;
doc < data.size() && doc < startDoc + numDocs;
doc++) {
TopicAssignment document = data.get(doc);
FeatureSequence tokens = (FeatureSequence) document.instance.getData();
FeatureSequence topicSequence = (FeatureSequence) document.topicSequence;
int[] topics = topicSequence.getFeatures();
for (int position = 0; position < tokens.size(); position++) {
int topic = topics[position];
if (topic == ParallelTopicModel.UNASSIGNED_TOPIC) { continue; }
tokensPerTopic[topic]++;
// The format for these arrays is
// the topic in the rightmost bits
// the count in the remaining (left) bits.
// Since the count is in the high bits, sorting (desc)
// by the numeric value of the int guarantees that
// higher counts will be before the lower counts.
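				//  For example, with topicBits = 7 the list
				//  {(count 5, topic 2), (count 3, topic 0)} is stored as
				//  { (5<<7)+2, (3<<7)+0 } = { 642, 384 }, which is
				//  descending as plain ints.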
int type = tokens.getIndexAtPosition(position);
int[] currentTypeTopicCounts = typeTopicCounts[ type ];
// Start by assuming that the array is either empty
// or is in sorted (descending) order.
// Here we are only adding counts, so if we find
// an existing location with the topic, we only need
// to ensure that it is not larger than its left neighbor.
int index = 0;
int currentTopic = currentTypeTopicCounts[index] & topicMask;
int currentValue;
				while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
					index++;
					if (index == currentTypeTopicCounts.length) {
						// Every slot for this type is full and none holds
						// the topic: the counts are inconsistent. Fail with
						// a clear message rather than falling off the array.
						throw new IllegalStateException("overflow on type " + type);
					}
					currentTopic = currentTypeTopicCounts[index] & topicMask;
				}
currentValue = currentTypeTopicCounts[index] >> topicBits;
if (currentValue == 0) {
// new value is 1, so we don't have to worry about sorting
// (except by topic suffix, which doesn't matter)
currentTypeTopicCounts[index] =
(1 << topicBits) + topic;
}
else {
currentTypeTopicCounts[index] =
((currentValue + 1) << topicBits) + topic;
// Now ensure that the array is still sorted by
// bubbling this value up.
while (index > 0 &&
currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
int temp = currentTypeTopicCounts[index];
currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
currentTypeTopicCounts[index - 1] = temp;
index--;
}
}
}
}
}
public void run () {
try {
if (! isFinished) { System.out.println("already running!"); return; }
isFinished = false;
// Initialize the smoothing-only sampling bucket
smoothingOnlyMass = 0;
// Initialize the cached coefficients, using only smoothing.
// These values will be selectively replaced in documents with
// non-zero counts in particular topics.
for (int topic=0; topic < numTopics; topic++) {
smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
}
for (int doc = startDoc;
doc < data.size() && doc < startDoc + numDocs;
doc++) {
/*
if (doc % 10000 == 0) {
System.out.println("processing doc " + doc);
}
*/
FeatureSequence tokenSequence =
(FeatureSequence) data.get(doc).instance.getData();
LabelSequence topicSequence =
(LabelSequence) data.get(doc).topicSequence;
sampleTopicsForOneDoc (tokenSequence, topicSequence,
true);
}
if (shouldBuildLocalCounts) {
buildLocalTypeTopicCounts();
}
shouldSaveState = false;
isFinished = true;
} catch (Exception e) {
e.printStackTrace();
}
}
protected void sampleTopicsForOneDoc (FeatureSequence tokenSequence,
FeatureSequence topicSequence,
boolean readjustTopicsAndStats /* currently ignored */) {
int[] oneDocTopics = topicSequence.getFeatures();
int[] currentTypeTopicCounts;
int type, oldTopic, newTopic;
double topicWeightsSum;
int docLength = tokenSequence.getLength();
int[] localTopicCounts = new int[numTopics];
int[] localTopicIndex = new int[numTopics];
// populate topic counts
for (int position = 0; position < docLength; position++) {
if (oneDocTopics[position] == ParallelTopicModel.UNASSIGNED_TOPIC) { continue; }
localTopicCounts[oneDocTopics[position]]++;
}
// Build an array that densely lists the topics that
// have non-zero counts.
int denseIndex = 0;
for (int topic = 0; topic < numTopics; topic++) {
if (localTopicCounts[topic] != 0) {
localTopicIndex[denseIndex] = topic;
denseIndex++;
}
}
// Record the total number of non-zero topics
int nonZeroTopics = denseIndex;
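		// For example, if the document contains only topics 3 and 7,
		// then localTopicIndex = {3, 7, 0, 0, ...} and nonZeroTopics = 2.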
// Initialize the topic count/beta sampling bucket
double topicBetaMass = 0.0;
// Initialize cached coefficients and the topic/beta
// normalizing constant.
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
int n = localTopicCounts[topic];
// initialize the normalization constant for the (B * n_{t|d}) term
topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
// update the coefficients for the non-zero topics
cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
}
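		// The full sampling distribution for a word w in this document
		// factors into three buckets (cf. the SparseLDA decomposition of
		// Yao, Mimno & McCallum, KDD 2009):
		//   p(t | w, d)  ~  (alpha[t] + n_{t|d}) * (beta + n_{w|t}) / (betaSum + N_t)
		//     = alpha[t] * beta / (betaSum + N_t)                 [smoothingOnlyMass]
		//     + n_{t|d} * beta / (betaSum + N_t)                  [topicBetaMass]
		//     + (alpha[t] + n_{t|d}) * n_{w|t} / (betaSum + N_t)  [topicTermMass]
		// Only the last bucket depends on the word type, and it is sparse
		// in the number of topics that currently contain the word.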
double topicTermMass = 0.0;
double[] topicTermScores = new double[numTopics];
int[] topicTermIndices;
int[] topicTermValues;
int i;
double score;
// Iterate over the positions (words) in the document
for (int position = 0; position < docLength; position++) {
type = tokenSequence.getIndexAtPosition(position);
oldTopic = oneDocTopics[position];
currentTypeTopicCounts = typeTopicCounts[type];
if (oldTopic != ParallelTopicModel.UNASSIGNED_TOPIC) {
// Remove this token from all counts.
// Remove this topic's contribution to the
// normalizing constants
smoothingOnlyMass -= alpha[oldTopic] * beta /
(tokensPerTopic[oldTopic] + betaSum);
topicBetaMass -= beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Decrement the local doc/topic counts
localTopicCounts[oldTopic]--;
// Maintain the dense index, if we are deleting
// the old topic
if (localTopicCounts[oldTopic] == 0) {
// First get to the dense location associated with
// the old topic.
denseIndex = 0;
// We know it's in there somewhere, so we don't
// need bounds checking.
while (localTopicIndex[denseIndex] != oldTopic) {
denseIndex++;
}
// shift all remaining dense indices to the left.
while (denseIndex < nonZeroTopics) {
if (denseIndex < localTopicIndex.length - 1) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex + 1];
}
denseIndex++;
}
nonZeroTopics --;
}
// Decrement the global topic count totals
tokensPerTopic[oldTopic]--;
assert(tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
// Add the old topic's contribution back into the
// normalizing constants.
smoothingOnlyMass += alpha[oldTopic] * beta /
(tokensPerTopic[oldTopic] + betaSum);
topicBetaMass += beta * localTopicCounts[oldTopic] /
(tokensPerTopic[oldTopic] + betaSum);
// Reset the cached coefficient for this topic
cachedCoefficients[oldTopic] =
(alpha[oldTopic] + localTopicCounts[oldTopic]) /
(tokensPerTopic[oldTopic] + betaSum);
}
// Now go over the type/topic counts, decrementing
// where appropriate, and calculating the score
// for each topic at the same time.
int index = 0;
int currentTopic, currentValue;
boolean alreadyDecremented = (oldTopic == ParallelTopicModel.UNASSIGNED_TOPIC);
topicTermMass = 0.0;
while (index < currentTypeTopicCounts.length &&
currentTypeTopicCounts[index] > 0) {
currentTopic = currentTypeTopicCounts[index] & topicMask;
currentValue = currentTypeTopicCounts[index] >> topicBits;
if (! alreadyDecremented &&
currentTopic == oldTopic) {
// We're decrementing and adding up the
// sampling weights at the same time, but
// decrementing may require us to reorder
// the topics, so after we're done here,
// look at this cell in the array again.
currentValue --;
if (currentValue == 0) {
currentTypeTopicCounts[index] = 0;
}
else {
currentTypeTopicCounts[index] =
(currentValue << topicBits) + oldTopic;
}
// Shift the reduced value to the right, if necessary.
int subIndex = index;
while (subIndex < currentTypeTopicCounts.length - 1 &&
currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]) {
int temp = currentTypeTopicCounts[subIndex];
currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
currentTypeTopicCounts[subIndex + 1] = temp;
subIndex++;
}
alreadyDecremented = true;
}
else {
score =
cachedCoefficients[currentTopic] * currentValue;
topicTermMass += score;
topicTermScores[index] = score;
index++;
}
}
double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
double origSample = sample;
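			// Check the topic-term bucket first: it typically holds most
			// of the mass, so the common case touches only the (sparse)
			// topics that already contain this word type.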
// Make sure it actually gets set
newTopic = -1;
if (sample < topicTermMass) {
//topicTermCount++;
i = -1;
while (sample > 0) {
i++;
sample -= topicTermScores[i];
}
newTopic = currentTypeTopicCounts[i] & topicMask;
currentValue = currentTypeTopicCounts[i] >> topicBits;
currentTypeTopicCounts[i] = ((currentValue + 1) << topicBits) + newTopic;
// Bubble the new value up, if necessary
while (i > 0 &&
currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
int temp = currentTypeTopicCounts[i];
currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
currentTypeTopicCounts[i - 1] = temp;
i--;
}
}
else {
sample -= topicTermMass;
if (sample < topicBetaMass) {
//betaTopicCount++;
sample /= beta;
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
sample -= localTopicCounts[topic] /
(tokensPerTopic[topic] + betaSum);
if (sample <= 0.0) {
newTopic = topic;
break;
}
}
}
else {
					//smoothingOnlyCount++;
					sample -= topicBetaMass;
					sample /= beta;
					newTopic = 0;
					sample -= alpha[newTopic] /
						(tokensPerTopic[newTopic] + betaSum);
					// Guard against floating-point round-off pushing the
					// sample past the last topic.
					while (sample > 0.0 && newTopic < numTopics - 1) {
						newTopic++;
						sample -= alpha[newTopic] /
							(tokensPerTopic[newTopic] + betaSum);
					}
}
// Move to the position for the new topic,
// which may be the first empty position if this
// is a new topic for this word.
index = 0;
while (currentTypeTopicCounts[index] > 0 &&
(currentTypeTopicCounts[index] & topicMask) != newTopic) {
index++;
if (index == currentTypeTopicCounts.length) {
System.err.println("type: " + type + " new topic: " + newTopic);
for (int k=0; k<currentTypeTopicCounts.length; k++) {
System.err.print((currentTypeTopicCounts[k] & topicMask) + ":" +
(currentTypeTopicCounts[k] >> topicBits) + " ");
}
System.err.println();
}
}
// index should now be set to the position of the new topic,
// which may be an empty cell at the end of the list.
if (currentTypeTopicCounts[index] == 0) {
// inserting a new topic, guaranteed to be in
// order w.r.t. count, if not topic.
currentTypeTopicCounts[index] = (1 << topicBits) + newTopic;
}
else {
currentValue = currentTypeTopicCounts[index] >> topicBits;
currentTypeTopicCounts[index] = ((currentValue + 1) << topicBits) + newTopic;
// Bubble the increased value left, if necessary
while (index > 0 &&
currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
int temp = currentTypeTopicCounts[index];
currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
currentTypeTopicCounts[index - 1] = temp;
index--;
}
}
}
if (newTopic == -1) {
System.err.println("WorkerRunnable sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
topicBetaMass + " " + topicTermMass);
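				// This branch should be unreachable; reaching it suggests
				// numerical round-off at a bucket boundary or inconsistent
				// counts. Fall back to the last topic so sampling can continue.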
newTopic = numTopics-1; // TODO is this appropriate
//throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
}
//assert(newTopic != -1);
// Put that new topic into the counts
oneDocTopics[position] = newTopic;
smoothingOnlyMass -= alpha[newTopic] * beta /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass -= beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
localTopicCounts[newTopic]++;
// If this is a new topic for this document,
// add the topic to the dense index.
if (localTopicCounts[newTopic] == 1) {
// First find the point where we
// should insert the new topic by going to
// the end (which is the only reason we're keeping
// track of the number of non-zero
// topics) and working backwards
denseIndex = nonZeroTopics;
while (denseIndex > 0 &&
localTopicIndex[denseIndex - 1] > newTopic) {
localTopicIndex[denseIndex] =
localTopicIndex[denseIndex - 1];
denseIndex--;
}
localTopicIndex[denseIndex] = newTopic;
nonZeroTopics++;
}
tokensPerTopic[newTopic]++;
// update the coefficients for the non-zero topics
cachedCoefficients[newTopic] =
(alpha[newTopic] + localTopicCounts[newTopic]) /
(tokensPerTopic[newTopic] + betaSum);
smoothingOnlyMass += alpha[newTopic] * beta /
(tokensPerTopic[newTopic] + betaSum);
topicBetaMass += beta * localTopicCounts[newTopic] /
(tokensPerTopic[newTopic] + betaSum);
}
if (shouldSaveState) {
// Update the document-topic count histogram,
// for dirichlet estimation
docLengthCounts[ docLength ]++;
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
topicDocCounts[topic][ localTopicCounts[topic] ]++;
}
}
// Clean up our mess: reset the coefficients to values with only
// smoothing. The next doc will update its own non-zero topics...
for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
int topic = localTopicIndex[denseIndex];
cachedCoefficients[topic] =
alpha[topic] / (tokensPerTopic[topic] + betaSum);
}
}
}