/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.lda;
import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Random;
/**
 * Estimates an LDA model from a corpus of documents represented as sparse vectors of word counts. After each
 * iteration it writes the updated model state: per-topic log word counts, per-topic log totals, and the corpus
 * log likelihood.
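 * <p>
 * A typical command-line invocation looks roughly like the following (a sketch only; option names follow the
 * {@code DefaultOptionCreator} conventions used below and may differ between releases):
 * <pre>
 *   bin/mahout org.apache.mahout.clustering.lda.LDADriver \
 *     --input path/to/tf-vectors --output path/to/lda-output \
 *     --numTopics 20 --maxIter 10 --overwrite
 * </pre>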
*/
public final class LDADriver extends AbstractJob {
private static final String TOPIC_SMOOTHING_OPTION = "topicSmoothing";
private static final String NUM_TOPICS_OPTION = "numTopics";
// TODO: sequential iteration is not yet correct.
// private static final String SEQUENTIAL_OPTION = "sequential";
static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
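  // Model state is persisted as SequenceFiles of (IntPairWritable(topic, word), DoubleWritable) records.
  // Two sentinel keys are reserved: a "word" of TOPIC_SUM_KEY holds a topic's log total, and the record
  // keyed (LOG_LIKELIHOOD_KEY, LOG_LIKELIHOOD_KEY) holds the corpus log likelihood for the iteration.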
static final int LOG_LIKELIHOOD_KEY = -2;
static final int TOPIC_SUM_KEY = -1;
static final double OVERALL_CONVERGENCE = 1.0E-5;
private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
private LDAState state = null;
private LDAInference inference = null;
private Iterable<Pair<Writable, VectorWritable>> trainingCorpus = null;
public static void main(String[] args) throws Exception {
ToolRunner.run(new Configuration(), new LDADriver(), args);
}
public static LDAState createState(Configuration job) {
return createState(job, false);
}
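  /**
   * Builds an {@link LDAState} from the configuration: when {@code empty} is true an uninitialized state is
   * returned, otherwise the state part files named by {@code STATE_IN_KEY} are loaded.
   */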
public static LDAState createState(Configuration job, boolean empty) {
String statePath = job.get(STATE_IN_KEY);
int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));
Path dir = new Path(statePath);
// TODO scalability bottleneck: numWords * numTopics * 8bytes for the driver *and* M/R classes
DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
double[] logTotals = new double[numTopics];
Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
double ll = 0.0;
if (empty) {
return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
}
for (Pair<IntPairWritable,DoubleWritable> record
: new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(new Path(dir, "part-*"),
PathType.GLOB,
null,
null,
true,
job)) {
IntPairWritable key = record.getFirst();
DoubleWritable value = record.getSecond();
int topic = key.getFirst();
int word = key.getSecond();
if (word == TOPIC_SUM_KEY) {
logTotals[topic] = value.get();
Preconditions.checkArgument(!Double.isInfinite(value.get()));
} else if (topic == LOG_LIKELIHOOD_KEY) {
ll = value.get();
} else {
        Preconditions.checkArgument(topic >= 0, "topic should be non-negative, not %s", topic);
        Preconditions.checkArgument(word >= 0, "word should be non-negative, not %s", word);
Preconditions.checkArgument(pWgT.getQuick(topic, word) == 0.0);
pWgT.setQuick(topic, word, value.get());
Preconditions.checkArgument(!Double.isInfinite(pWgT.getQuick(topic, word)));
}
}
return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
}
@Override
public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
addInputOption();
addOutputOption();
addOption(DefaultOptionCreator.overwriteOption().create());
addOption(NUM_TOPICS_OPTION, "k", "The total number of topics in the corpus", true);
addOption(TOPIC_SMOOTHING_OPTION, "a", "Topic smoothing parameter. Default is 50/numTopics.", "-1.0");
// addOption(SEQUENTIAL_OPTION, "seq", "Run sequentially (not Hadoop-based). Default is false.", "false");
addOption(DefaultOptionCreator.maxIterationsOption().withRequired(false).create());
if (parseArguments(args) == null) {
return -1;
}
Path input = getInputPath();
Path output = getOutputPath();
if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
HadoopUtil.delete(getConf(), output);
}
int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
int numTopics = Integer.parseInt(getOption(NUM_TOPICS_OPTION));
int numWords = determineNumberOfWordsFromFirstVector();
double topicSmoothing = Double.parseDouble(getOption(TOPIC_SMOOTHING_OPTION));
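    // any value below 1 (including the -1.0 default) falls back to the conventional 50/numTopics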
if (topicSmoothing < 1) {
topicSmoothing = 50.0 / numTopics;
}
boolean runSequential = false; // Boolean.parseBoolean(getOption(SEQUENTIAL_OPTION));
run(getConf(), input, output, numTopics, numWords, topicSmoothing, maxIterations, runSequential);
return 0;
}
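  /** Returns the most recent {@code state-N} directory under {@code stateDir} (highest N), or null if none exists yet. */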
private static Path getLastKnownStatePath(Configuration conf, Path stateDir) throws IOException {
FileSystem fs = FileSystem.get(conf);
Path lastPath = null;
int maxIteration = Integer.MIN_VALUE;
for (FileStatus fstatus : fs.globStatus(new Path(stateDir, "state-*"))) {
try {
int iteration = Integer.parseInt(fstatus.getPath().getName().split("-")[1]);
        if (iteration > maxIteration) {
maxIteration = iteration;
lastPath = fstatus.getPath();
}
      } catch (NumberFormatException nfe) {
throw new IOException(nfe);
}
}
return lastPath;
}
/**
* Determine the number of words based on the size of the input vectors.
   * Note: we can't just check the first part file since it might contain a null vector (this can
   * happen when seq2sparse is run over a small data set with a large number
   * of reducers).
*/
private int determineNumberOfWordsFromFirstVector() throws IOException {
SequenceFileDirValueIterator<VectorWritable> it =
new SequenceFileDirValueIterator<VectorWritable>(getInputPath(),
PathType.LIST,
PathFilters.logsCRCFilter(),
null,
true,
getConf());
try {
while (it.hasNext()) {
VectorWritable v = it.next();
if (v.get() != null) {
return v.get().size();
}
}
} finally {
Closeables.closeQuietly(it);
}
log.warn("can't determine number of words; no vectors in {}", getInputPath());
return 0;
}
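  /**
   * Runs LDA programmatically. A sketch of a typical call (the paths and parameter values here are
   * illustrative only):
   * <pre>
   *   Configuration conf = new Configuration();
   *   double negLogLikelihood = new LDADriver().run(conf,
   *       new Path("tf-vectors"),   // corpus of term-frequency vectors
   *       new Path("lda-output"),   // where the state-N and docTopics directories are written
   *       20,                       // numTopics
   *       50000,                    // numWords (vocabulary size)
   *       2.5,                      // topicSmoothing, e.g. 50/numTopics
   *       10,                       // maxIterations; values below 1 mean "run until convergence"
   *       false);                   // runSequential: false selects the MapReduce path
   * </pre>
   * @return the negated log likelihood of the final model state
   */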
public double run(Configuration conf,
Path input,
Path output,
int numTopics,
int numWords,
double topicSmoothing,
int maxIterations,
boolean runSequential)
throws IOException, InterruptedException, ClassNotFoundException {
Path lastKnownState = getLastKnownStatePath(conf, output);
Path stateIn;
if (lastKnownState == null) {
stateIn = new Path(output, "state-0");
writeInitialState(stateIn, numTopics, numWords);
} else {
stateIn = lastKnownState;
}
conf.set(STATE_IN_KEY, stateIn.toString());
conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
double oldLL = Double.NEGATIVE_INFINITY;
boolean converged = false;
int iteration = Integer.parseInt(stateIn.getName().split("-")[1]) + 1;
for (; ((maxIterations < 1) || (iteration <= maxIterations)) && !converged; iteration++) {
log.info("LDA Iteration {}", iteration);
conf.set(STATE_IN_KEY, stateIn.toString());
// point the output to a new directory per iteration
Path stateOut = new Path(output, "state-" + iteration);
double ll = runSequential
? runIterationSequential(conf, input, stateOut)
: runIteration(conf, input, stateIn, stateOut);
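      // log likelihoods are negative, so this ratio is positive when ll improves on oldLL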
double relChange = (oldLL - ll) / oldLL;
// now point the input to the old output directory
log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
log.info("(Old LL: {})", oldLL);
log.info("(Rel Change: {})", relChange);
converged = (iteration > 3) && (relChange < OVERALL_CONVERGENCE);
stateIn = stateOut;
oldLL = ll;
}
if(runSequential) {
computeDocumentTopicProbabilitiesSequential(conf, input, new Path(output, "docTopics"));
} else {
computeDocumentTopicProbabilities(conf,
input,
stateIn,
new Path(output, "docTopics"),
numTopics,
numWords,
topicSmoothing);
}
return -oldLL;
}
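  /** Seeds state-0 with small random pseudo-counts (stored in log space), writing one part file per topic. */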
private static void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
Configuration job = new Configuration();
FileSystem fs = statePath.getFileSystem(job);
DoubleWritable v = new DoubleWritable();
Random random = RandomUtils.getRandom();
for (int k = 0; k < numTopics; ++k) {
Path path = new Path(statePath, "part-" + k);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
try {
double total = 0.0; // total number of pseudo counts we made
for (int w = 0; w < numWords; ++w) {
Writable kw = new IntPairWritable(k, w);
          // a small random pseudo-count; the 1.0E-8 floor keeps it strictly positive so Math.log stays finite
double pseudocount = random.nextDouble() + 1.0E-8;
total += pseudocount;
v.set(Math.log(pseudocount));
writer.append(kw, v);
}
Writable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
v.set(Math.log(total));
writer.append(kTsk, v);
} finally {
Closeables.closeQuietly(writer);
}
}
}
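  /** Writes the given state as one part file per topic, plus a final part file holding the overall log likelihood. */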
private static void writeState(Configuration job, LDAState state, Path statePath) throws IOException {
FileSystem fs = statePath.getFileSystem(job);
DoubleWritable v = new DoubleWritable();
for (int k = 0; k < state.getNumTopics(); ++k) {
Path path = new Path(statePath, "part-" + k);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
try {
for (int w = 0; w < state.getNumWords(); ++w) {
Writable kw = new IntPairWritable(k, w);
v.set(state.logProbWordGivenTopic(w,k) + state.getLogTotal(k));
writer.append(kw, v);
}
Writable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
v.set(state.getLogTotal(k));
writer.append(kTsk, v);
} finally {
Closeables.closeQuietly(writer);
}
}
Path path = new Path(statePath, "part-" + LOG_LIKELIHOOD_KEY);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
try {
Writable kTsk = new IntPairWritable(LOG_LIKELIHOOD_KEY,LOG_LIKELIHOOD_KEY);
v.set(state.getLogLikelihood());
writer.append(kTsk, v);
} finally {
Closeables.closeQuietly(writer);
}
}
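  /** Scans the state's part files for the record keyed (LOG_LIKELIHOOD_KEY, LOG_LIKELIHOOD_KEY) and returns its value. */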
private static double findLL(Path statePath, Configuration job) throws IOException {
FileSystem fs = statePath.getFileSystem(job);
double ll = 0.0;
for (FileStatus status : fs.globStatus(new Path(statePath, "part-*"))) {
Path path = status.getPath();
SequenceFileIterator<IntPairWritable,DoubleWritable> iterator =
new SequenceFileIterator<IntPairWritable,DoubleWritable>(path, true, job);
try {
while (iterator.hasNext()) {
Pair<IntPairWritable,DoubleWritable> record = iterator.next();
if (record.getFirst().getFirst() == LOG_LIKELIHOOD_KEY) {
ll = record.getSecond().get();
break;
}
}
} finally {
Closeables.closeQuietly(iterator);
}
}
return ll;
}
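  /** Performs one in-memory (non-MapReduce) iteration over the cached corpus and writes the resulting state. */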
private double runIterationSequential(Configuration conf, Path input, Path stateOut) throws IOException {
if (state == null) {
state = createState(conf);
}
if (trainingCorpus == null) {
Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
Collection<Pair<Writable, VectorWritable>> corpus = new LinkedList<Pair<Writable, VectorWritable>>();
      for (FileStatus fileStatus : FileSystem.get(conf).globStatus(new Path(input, "part-*"))) {
        Path inputPart = fileStatus.getPath();
        SequenceFile.Reader reader = new SequenceFile.Reader(FileSystem.get(conf), inputPart, conf);
        try {
          Writable key = ReflectionUtils.newInstance(keyClass, conf);
          VectorWritable value = new VectorWritable();
          while (reader.next(key, value)) {
            Writable nextKey = ReflectionUtils.newInstance(keyClass, conf);
            VectorWritable nextValue = new VectorWritable();
            corpus.add(new Pair<Writable,VectorWritable>(key, value));
            key = nextKey;
            value = nextValue;
          }
        } finally {
          // close the reader so part-file handles are not leaked
          Closeables.closeQuietly(reader);
        }
      }
trainingCorpus = corpus;
}
if (inference == null) {
inference = new LDAInference(state);
}
LDAState newState = createState(conf, true);
double ll = 0.0;
for (Pair<Writable, VectorWritable> slice : trainingCorpus) {
LDAInference.InferredDocument doc;
Vector wordCounts = slice.getSecond().get();
try {
doc = inference.infer(wordCounts);
} catch (ArrayIndexOutOfBoundsException e1) {
        throw new IllegalStateException(
            "This is probably because the numWords value is set too small.\n"
            + "\tIt needs to be at least the number of distinct terms in the corpus, and it can be\n"
            + "\tlarger if some storage inefficiency can be tolerated.", e1);
}
      for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
        Vector.Element e = iter.next();
        int w = e.index();
        for (int k = 0; k < state.getNumTopics(); ++k) {
          double vwUpdate = doc.phi(k, w) + Math.log(e.get());
          newState.updateLogProbGivenTopic(w, k, vwUpdate); // update state.topicWordProbabilities[v,w]!
          newState.updateLogTotals(k, vwUpdate);
        }
      }
      // accumulate the document's log likelihood once per document rather than once per word
      ll += doc.getLogLikelihood();
}
newState.setLogLikelihood(ll);
writeState(conf, newState, stateOut);
state = newState;
return ll;
}
  /**
   * Run a single MapReduce iteration using the supplied arguments.
   * @param conf
   *          the Configuration to use
   * @param input
   *          the directory pathname for input points
   * @param stateIn
   *          the directory pathname for input state
   * @param stateOut
   *          the directory pathname for output state
   * @return the log likelihood recorded in the output state
   */
private static double runIteration(Configuration conf,
Path input,
Path stateIn,
Path stateOut)
throws IOException, InterruptedException, ClassNotFoundException {
conf.set(STATE_IN_KEY, stateIn.toString());
Job job = new Job(conf, "LDA Driver running runIteration over stateIn: " + stateIn);
job.setOutputKeyClass(IntPairWritable.class);
job.setOutputValueClass(DoubleWritable.class);
FileInputFormat.addInputPaths(job, input.toString());
FileOutputFormat.setOutputPath(job, stateOut);
job.setMapperClass(LDAWordTopicMapper.class);
job.setReducerClass(LDAReducer.class);
job.setCombinerClass(LDAReducer.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setJarByClass(LDADriver.class);
if (!job.waitForCompletion(true)) {
throw new InterruptedException("LDA Iteration failed processing " + stateIn);
}
return findLL(stateOut, conf);
}
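  /** Runs a map-only job that uses the final model state to write a p(topic|doc) vector for each document, keyed by the corpus's document keys. */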
private static void computeDocumentTopicProbabilities(Configuration conf,
Path input,
Path stateIn,
Path outputPath,
int numTopics,
int numWords,
double topicSmoothing)
throws IOException, InterruptedException, ClassNotFoundException {
conf.set(STATE_IN_KEY, stateIn.toString());
conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
Job job = new Job(conf, "LDA Driver computing p(topic|doc) for all docs/topics with stateIn: " + stateIn);
job.setOutputKeyClass(peekAtSequenceFileForKeyType(conf, input));
job.setOutputValueClass(VectorWritable.class);
FileInputFormat.addInputPaths(job, input.toString());
FileOutputFormat.setOutputPath(job, outputPath);
job.setMapperClass(LDADocumentTopicMapper.class);
job.setNumReduceTasks(0);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setJarByClass(LDADriver.class);
if (!job.waitForCompletion(true)) {
throw new InterruptedException("LDA failed to compute and output document topic probabilities with: "+ stateIn);
}
}
private void computeDocumentTopicProbabilitiesSequential(Configuration conf, Path input, Path outputPath)
throws IOException {
FileSystem fs = input.getFileSystem(conf);
Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, outputPath, keyClass, VectorWritable.class);
try {
Writable key = ReflectionUtils.newInstance(keyClass, conf);
Writable vw = new VectorWritable();
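      // Note: nothing populates key or vw from the corpus or the inference result, so this sequential
      // output path is incomplete (see the sequential-iteration TODO at the top of the class).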
for (Pair<Writable, VectorWritable> slice : trainingCorpus) {
Vector wordCounts = slice.getSecond().get();
try {
inference.infer(wordCounts);
} catch (ArrayIndexOutOfBoundsException e1) {
          throw new IllegalStateException(
              "This is probably because the numWords value is set too small.\n"
              + "\tIt needs to be at least the number of distinct terms in the corpus, and it can be\n"
              + "\tlarger if some storage inefficiency can be tolerated.", e1);
}
writer.append(key, vw);
}
} finally {
Closeables.closeQuietly(writer);
}
}
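  /** Opens the input as a SequenceFile to discover its key class, falling back to Text if it cannot be read. */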
  @SuppressWarnings("unchecked")
  private static Class<? extends Writable> peekAtSequenceFileForKeyType(Configuration conf, Path input) {
    SequenceFile.Reader reader = null;
    try {
      reader = new SequenceFile.Reader(FileSystem.get(conf), input, conf);
      return (Class<? extends Writable>) reader.getKeyClass();
    } catch (IOException ioe) {
      return Text.class;
    } finally {
      // close the reader so the peeked file handle is not leaked
      Closeables.closeQuietly(reader);
    }
  }
}