/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.lda;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Runs inference on the input documents (which are sparse vectors of word counts) and outputs the sufficient
 * statistics for the word-topic assignments.
 */
public class LDAWordTopicMapper extends Mapper<WritableComparable<?>,VectorWritable,IntPairWritable,DoubleWritable> {

  private LDAState state;      // current model state (topic-word parameters)
  private LDAInference infer;  // per-document inference engine built from the state

  @Override
  protected void map(WritableComparable<?> key,
                     VectorWritable wordCountsWritable,
                     Context context) throws IOException, InterruptedException {
    Vector wordCounts = wordCountsWritable.get();
    LDAInference.InferredDocument doc;
    try {
      doc = infer.infer(wordCounts);
    } catch (ArrayIndexOutOfBoundsException e1) {
      throw new IllegalStateException(
          "This is probably because the --numWords argument is set too small. \n"
          + "\tIt needs to be >= the number of words (terms, actually) in the corpus and can be \n"
          + "\tlarger if some storage inefficiency can be tolerated.", e1);
    }

    double[] logTotals = new double[state.getNumTopics()];
    Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);

    // Output sufficient statistics for each word; these are pseudo-log counts.
    DoubleWritable v = new DoubleWritable();
    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
      Vector.Element e = iter.next();
      int w = e.index();

      for (int k = 0; k < state.getNumTopics(); ++k) {
        v.set(doc.phi(k, w) + Math.log(e.get()));

        IntPairWritable kw = new IntPairWritable(k, w);

        // Output this (topic, word) pair's log-probability contribution.
        context.write(kw, v);
        logTotals[k] = LDAUtil.logSum(logTotals[k], v.get());
      }
    }

    // Output the totals for the statistics. This is to make
    // normalizing a lot easier.
    for (int k = 0; k < state.getNumTopics(); ++k) {
      IntPairWritable kw = new IntPairWritable(k, LDADriver.TOPIC_SUM_KEY);
      v.set(logTotals[k]);
      assert !Double.isNaN(v.get());
      context.write(kw, v);
    }

    // Output this document's log-likelihood under its dedicated key.
    IntPairWritable llk = new IntPairWritable(LDADriver.LOG_LIKELIHOOD_KEY, LDADriver.LOG_LIKELIHOOD_KEY);
    v.set(doc.getLogLikelihood());
    context.write(llk, v);
  }

  public void configure(LDAState myState) {
    this.state = myState;
    this.infer = new LDAInference(state);
  }

  public void configure(Configuration job) {
    LDAState myState = LDADriver.createState(job);
    configure(myState);
  }

  @Override
  protected void setup(Context context) {
    configure(context.getConfiguration());
  }
}
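
// A minimal sketch of the normalization the TOPIC_SUM_KEY totals enable downstream
// (an assumption about the surrounding job, ignoring any smoothing the driver may
// apply; reducer/driver details are not shown in this file). Once the reducer has
// log-summed the per-document contributions, the per-topic word distribution can
// be recovered as:
//
//   double logPWordGivenTopic = logProb(k, w) - logProb(k, LDADriver.TOPIC_SUM_KEY);
//   double pWordGivenTopic = Math.exp(logPWordGivenTopic); // sums to 1 over w for each topic k
//
// which is why the per-topic totals are emitted alongside the individual
// (topic, word) contributions above.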