/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.hadoop.stochasticsvd;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import org.apache.commons.lang.Validate;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile.CompressionType;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.compress.DefaultCodec;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.lib.MultipleOutputs;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRLastStep;
/**
* Bt job. For details, see working notes in MAHOUT-376.
* <P>
*
* Uses the deprecated Hadoop API wherever the new API has not been updated
* (MAHOUT-593), hence {@code @SuppressWarnings("deprecation")}.
* <P>
*
* This job outputs either Bt in its standard output, or upper triangular
* matrices representing BBt partial sums if that is requested. If the latter
* mode is enabled, then we accumulate BBt outer product sums in an upper
* triangular accumulator and output it at the end of the job, thus saving space
* and a BBt job.
* <P>
*
* This job also outputs Q and Bt and optionally BBt. Bt is output to standard
* job output (part-*) and Q and BBt use named multiple outputs.
*
* <P>
*
*/
@SuppressWarnings("deprecation")
public final class BtJob {
/**
 * Name of the named (multiple-outputs) output that receives the Q rows
 * emitted by the mapper; registered in {@code run} with the row-label class
 * as key and {@code VectorWritable} as value.
 */
public static final String OUTPUT_Q = "Q";
/**
 * Base name of the standard job output files that carry Bt; {@code run}
 * stores it under "mapreduce.output.basename".
 */
public static final String OUTPUT_BT = "part";
/**
 * Name of the named output that receives the BBt partial sum written by the
 * reducer during cleanup (only registered when BBt output is requested).
 */
public static final String OUTPUT_BBT = "bbt";
/**
 * Configuration key holding the QJob output path; the mapper resolves its
 * Q-hat file and the R-hat files relative to that path.
 */
public static final String PROP_QJOB_PATH = "ssvd.QJob.path";
/**
 * Boolean configuration key that turns BBt partial-sum output on in the
 * reducer. NOTE: the identifier is misspelled ("OUPTUT"); it is public, so
 * it is left as is for the sake of existing callers.
 */
public static final String PROP_OUPTUT_BBT_PRODUCTS =
"ssvd.BtJob.outputBBtProducts";
/**
 * Integer configuration key for the outer product block height; every
 * reader in this file falls back to -1 when it is not set.
 */
public static final String PROP_OUTER_PROD_BLOCK_HEIGHT =
"ssvd.outerProdBlockHeight";
/**
 * Configuration key whose mere presence (any non-null value) tells the
 * mapper to read the R-hat files from the distributed cache rather than from
 * the QJob output path; {@code run} sets it to "y" in broadcast mode.
 */
public static final String PROP_RHAT_BROADCAST = "ssvd.rhat.broadcast";
// NOTE(review): not referenced anywhere in this file; judging by the name it
// is a zero-fraction cutoff for choosing sparse storage -- confirm at the
// usage site.
static final double SPARSE_ZEROS_PCT_THRESHOLD = 0.1;
/** Not instantiable: this class only hosts constants, nested task classes and {@code run}. */
private BtJob() {
}
/**
 * Map side of the Bt job.
 * <P>
 *
 * For every input row of A the mapper pulls the next Q row out of
 * {@link QRLastStep}, writes that Q row to the {@link #OUTPUT_Q} named output
 * under the key of the A row, and feeds the rows of the outer product of the
 * A row and the Q row into a {@link SparseRowBlockAccumulator}, which in turn
 * writes accumulated blocks to the regular map output.
 * <P>
 *
 * Everything opened in {@link #setup} is registered in {@code closeables} and
 * released in {@link #cleanup}.
 */
public static class BtMapper extends
    Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable> {

  private QRLastStep qr;
  /*
   * Resources are pushed with addFirst(), so the most recently registered one
   * comes first in the deque handed to IOUtils.close().
   */
  private final Deque<Closeable> closeables = new ArrayDeque<Closeable>();

  private int blockNum;
  private MultipleOutputs outputs;
  private final VectorWritable qRowValue = new VectorWritable();
  // scratch row reused for every outer product row handed to btCollector
  private Vector btRow;
  private SparseRowBlockAccumulator btCollector;
  // context of the most recent map() call; used by the block collector
  private Context mapContext;

  /**
   * Closes everything registered during {@link #setup}.
   */
  @Override
  protected void cleanup(Context context) throws IOException,
    InterruptedException {
    IOUtils.close(closeables);
  }

  /**
   * Writes one Q row to the {@link #OUTPUT_Q} named output.
   */
  @SuppressWarnings("unchecked")
  private void outputQRow(Writable key, Writable value) throws IOException {
    outputs.getCollector(OUTPUT_Q, null).collect(key, value);
  }

  /**
   * Scales the Q row by {@code mul} into the shared scratch row and passes it
   * to the block accumulator as Bt row number {@code index}.
   */
  private void collectScaledQRow(int index, double mul, Vector qRow, int kp)
    throws IOException {
    for (int j = 0; j < kp; j++) {
      btRow.setQuick(j, mul * qRow.getQuick(j));
    }
    btCollector.collect((long) index, btRow);
  }

  /**
   * We maintain A and QtHat inputs partitioned the same way, so we
   * essentially are performing map-side merge here of A and QtHats except
   * QtHat is stored not row-wise but block-wise.
   */
  @Override
  protected void map(Writable key, VectorWritable value, Context context)
    throws IOException, InterruptedException {

    mapContext = context;

    // output Bt outer products
    Vector aRow = value.get();
    Vector qRow = qr.next();
    int kp = qRow.size();

    // make sure Qs are inheriting A row labels.
    qRowValue.set(qRow);
    outputQRow(key, qRowValue);

    if (btRow == null) {
      btRow = new DenseVector(kp);
    }

    if (!aRow.isDense()) {
      // sparse A row: visit only the elements the non-zero iterator yields
      for (Iterator<Vector.Element> iter = aRow.iterateNonZero(); iter.hasNext();) {
        Vector.Element el = iter.next();
        collectScaledQRow(el.index(), el.get(), qRow, kp);
      }
    } else {
      // dense A row: visit every position
      int n = aRow.size();
      for (int i = 0; i < n; i++) {
        collectScaledQRow(i, aRow.getQuick(i), qRow, kp);
      }
    }
  }

  /**
   * Opens the R-hat input, either from the distributed cache (when
   * {@link #PROP_RHAT_BROADCAST} is present in the configuration) or from the
   * files matching {@code QJob.OUTPUT_RHAT + "-*"} under the QJob path.
   * <P>
   *
   * The R files are read _in order of task ids_, i.e. partitions (aka group
   * nums).
   */
  private static SequenceFileDirValueIterator<VectorWritable> openRHatInput(
      Configuration conf, Path qJobPath) throws IOException {

    boolean distributedRHat = conf.get(PROP_RHAT_BROADCAST) != null;
    if (!distributedRHat) {
      Path rPath = new Path(qJobPath, QJob.OUTPUT_RHAT + "-*");
      return new SequenceFileDirValueIterator<VectorWritable>(rPath,
                                                              PathType.GLOB,
                                                              null,
                                                              SSVDSolver.PARTITION_COMPARATOR,
                                                              true,
                                                              conf);
    }

    Path[] rFiles = DistributedCache.getLocalCacheFiles(conf);
    Validate.notNull(rFiles,
                     "no RHat files in distributed cache job definition");

    // the cached copies are read through a configuration whose default
    // file system is the local one
    Configuration localFsConf = new Configuration();
    localFsConf.set("fs.default.name", "file:///");
    return new SequenceFileDirValueIterator<VectorWritable>(rFiles,
                                                            SSVDSolver.PARTITION_COMPARATOR,
                                                            true,
                                                            localFsConf);
  }

  /**
   * Opens the Q-hat and R-hat inputs, the named outputs, the QR last step and
   * the Bt block accumulator, registering each of them for closing.
   */
  @Override
  protected void setup(Context context) throws IOException,
    InterruptedException {
    super.setup(context);

    Configuration conf = context.getConfiguration();
    Path qJobPath = new Path(conf.get(PROP_QJOB_PATH));

    /*
     * actually this is kind of dangerous because this routine thinks we need
     * to create file name for our current job and this will use -m- so it's
     * just serendipity we are calling it from the mapper too as the QJob did.
     */
    Path qInputPath =
      new Path(qJobPath, FileOutputFormat.getUniqueFile(context,
                                                        QJob.OUTPUT_QHAT,
                                                        ""));
    blockNum = context.getTaskAttemptID().getTaskID().getId();

    SequenceFileValueIterator<DenseBlockWritable> qhatInput =
      new SequenceFileValueIterator<DenseBlockWritable>(qInputPath,
                                                        true,
                                                        conf);
    closeables.addFirst(qhatInput);

    SequenceFileDirValueIterator<VectorWritable> rhatInput =
      openRHatInput(conf, qJobPath);
    Validate.isTrue(rhatInput.hasNext(), "Empty R-hat input!");
    closeables.addFirst(rhatInput);

    outputs = new MultipleOutputs(new JobConf(conf));
    closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(outputs));

    qr = new QRLastStep(qhatInput, rhatInput, blockNum);
    closeables.addFirst(qr);

    /*
     * it so happens that the current QRLastStep implementation preloads the
     * R sequence into memory in its constructor, so it is ok to close the
     * rhat input right away once it has been read to the end.
     */
    if (!rhatInput.hasNext()) {
      closeables.remove(rhatInput);
      rhatInput.close();
    }

    // adapter that routes accumulated blocks into the regular map output
    OutputCollector<LongWritable, SparseRowBlockWritable> btBlockCollector =
      new OutputCollector<LongWritable, SparseRowBlockWritable>() {

        @Override
        public void collect(LongWritable blockKey,
                            SparseRowBlockWritable block) throws IOException {
          try {
            mapContext.write(blockKey, block);
          } catch (InterruptedException exc) {
            /*
             * OutputCollector cannot throw InterruptedException, so it is
             * wrapped; re-assert the interrupt status first so the
             * interruption is not lost to code further up the stack.
             */
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted.", exc);
          }
        }
      };

    btCollector =
      new SparseRowBlockAccumulator(conf.getInt(PROP_OUTER_PROD_BLOCK_HEIGHT,
                                                -1),
                                    btBlockCollector);
    closeables.addFirst(btCollector);
  }
}
/**
 * Combiner for the Bt job: adds up all the sparse row blocks that arrive
 * under one key and emits the sum under that same key.
 */
public static class OuterProductCombiner
    extends
    Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable> {

  /** Running block sum; reused from one key to the next. */
  protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
  /** Resources to be released in {@link #cleanup}. */
  protected final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
  /** Value of {@link #PROP_OUTER_PROD_BLOCK_HEIGHT}, or -1 when it is not set. */
  protected int blockHeight;

  /**
   * Picks up the outer product block height from the job configuration.
   */
  @Override
  protected void setup(Context context) throws IOException,
    InterruptedException {
    Configuration jobConf = context.getConfiguration();
    blockHeight = jobConf.getInt(PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
  }

  /**
   * Folds every incoming block into the accumulator, writes the total, and
   * then empties the accumulator so that it is ready for the next key.
   */
  @Override
  protected void reduce(Writable key,
                        Iterable<SparseRowBlockWritable> values,
                        Context context) throws IOException,
    InterruptedException {
    Iterator<SparseRowBlockWritable> blocks = values.iterator();
    while (blocks.hasNext()) {
      accum.plusBlock(blocks.next());
    }
    context.write(key, accum);
    accum.clear();
  }

  /**
   * Releases whatever has been registered in {@code closeables}.
   */
  @Override
  protected void cleanup(Context context) throws IOException,
    InterruptedException {
    IOUtils.close(closeables);
  }
}
/**
 * Reducer for the Bt job: sums the sparse row blocks of each block key and
 * writes every row of the sum to the regular job output keyed by its row
 * number. When {@link #PROP_OUPTUT_BBT_PRODUCTS} is enabled it additionally
 * accumulates the outer self products of those rows in an upper triangular
 * matrix and writes that matrix to the {@link #OUTPUT_BBT} named output once,
 * during cleanup.
 */
public static class OuterProductReducer
    extends
    Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable> {

  /** Block sum for the key currently being reduced. */
  protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
  /** Resources to be released in {@link #cleanup}. */
  protected final Deque<Closeable> closeables = new ArrayDeque<Closeable>();
  /** Value of {@link #PROP_OUTER_PROD_BLOCK_HEIGHT}, or -1 when it is not set. */
  protected int blockHeight;

  private boolean outputBBt;
  private UpperTriangular mBBt;
  private MultipleOutputs outputs;
  // writables reused for every emitted Bt row
  private final IntWritable btKey = new IntWritable();
  private final VectorWritable btValue = new VectorWritable();

  /**
   * Reads the job settings; the BBt accumulator and the named outputs are
   * only created when BBt output has been requested.
   */
  @Override
  protected void setup(Context context) throws IOException,
    InterruptedException {
    Configuration jobConf = context.getConfiguration();
    blockHeight = jobConf.getInt(PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
    outputBBt = jobConf.getBoolean(PROP_OUPTUT_BBT_PRODUCTS, false);
    if (!outputBBt) {
      return;
    }

    int kParam = jobConf.getInt(QJob.PROP_K, -1);
    int pParam = jobConf.getInt(QJob.PROP_P, -1);
    Validate.isTrue(kParam > 0, "invalid k parameter");
    Validate.isTrue(pParam >= 0, "invalid p parameter");

    mBBt = new UpperTriangular(kParam + pParam);
    outputs = new MultipleOutputs(new JobConf(jobConf));
    closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(outputs));
  }

  /**
   * Adds the outer self product of one Bt row to the upper triangle held in
   * {@code mBBt}; products involving an exact zero are skipped.
   */
  private void accumulateBBt(Vector btRow) {
    int kp = mBBt.numRows();
    for (int i = 0; i < kp; i++) {
      double vi = btRow.get(i);
      if (vi == 0.0) {
        continue;
      }
      for (int j = i; j < kp; j++) {
        double vj = btRow.get(j);
        if (vj == 0.0) {
          continue;
        }
        mBBt.setQuick(i, j, mBBt.getQuick(i, j) + vi * vj);
      }
    }
  }

  /**
   * Sums all blocks of one block key and emits the resulting rows.
   */
  @Override
  protected void reduce(LongWritable key,
                        Iterable<SparseRowBlockWritable> values,
                        Context context) throws IOException,
    InterruptedException {

    accum.clear();
    Iterator<SparseRowBlockWritable> blocks = values.iterator();
    while (blocks.hasNext()) {
      accum.plusBlock(blocks.next());
    }

    /*
     * The accumulator now holds the summed rows of this block. Each row goes
     * to the job output under (block key * block height + row index within
     * the block) and, if requested, into the BBt partial sum.
     */
    long blockBase = key.get() * blockHeight;
    int rowCount = accum.getNumRows();
    Vector[] rows = accum.getRows();
    for (int r = 0; r < rowCount; r++) {
      Vector btRow = rows[r];
      btKey.set((int) (blockBase + accum.getRowIndices()[r]));
      btValue.set(btRow);
      context.write(btKey, btValue);

      if (outputBBt) {
        accumulateBBt(btRow);
      }
    }
  }

  /**
   * Writes the accumulated BBt partial sum, if one was requested, and then
   * closes the registered resources whether or not that write succeeded.
   */
  @Override
  protected void cleanup(Context context) throws IOException,
    InterruptedException {
    try {
      if (outputBBt) {
        @SuppressWarnings("unchecked")
        OutputCollector<Writable, Writable> bbtCollector =
          outputs.getCollector(OUTPUT_BBT, null);
        // the triangle's backing data is written as a single dense vector
        VectorWritable bbtValue =
          new VectorWritable(new DenseVector(mBBt.getData()));
        bbtCollector.collect(new IntWritable(), bbtValue);
      }
    } finally {
      IOUtils.close(closeables);
    }
  }
}
/**
 * Configures the Bt job, runs it and waits for it to finish.
 *
 * @param conf
 *          base configuration the job configuration is derived from
 * @param inputPathA
 *          input paths of the job (the rows of A), read as sequence files
 * @param inputPathQJob
 *          output path of the QJob; stored under {@link #PROP_QJOB_PATH} and
 *          searched for R-hat files in broadcast mode
 * @param outputPath
 *          output path of this job
 * @param minSplitSize
 *          minimum input split size; only applied when greater than 0
 * @param k
 *          value stored under {@code QJob.PROP_K}
 * @param p
 *          value stored under {@code QJob.PROP_P}
 * @param btBlockHeight
 *          value stored under {@link #PROP_OUTER_PROD_BLOCK_HEIGHT}
 * @param numReduceTasks
 *          number of reduce tasks
 * @param broadcast
 *          if true, the R-hat files found under {@code inputPathQJob} are
 *          added to the distributed cache and {@link #PROP_RHAT_BROADCAST}
 *          is set
 * @param labelClass
 *          key class of the {@link #OUTPUT_Q} named output
 * @param outputBBtProducts
 *          if true, the {@link #OUTPUT_BBT} named output is registered and
 *          {@link #PROP_OUPTUT_BBT_PRODUCTS} is set to true
 * @throws IOException
 *           if the job does not complete successfully, or on I/O failure
 *           while setting it up
 */
public static void run(Configuration conf,
                       Path[] inputPathA,
                       Path inputPathQJob,
                       Path outputPath,
                       int minSplitSize,
                       int k,
                       int p,
                       int btBlockHeight,
                       int numReduceTasks,
                       boolean broadcast,
                       Class<? extends Writable> labelClass,
                       boolean outputBBtProducts)
  throws ClassNotFoundException, InterruptedException, IOException {

  JobConf oldApiJob = new JobConf(conf);

  MultipleOutputs.addNamedOutput(oldApiJob,
                                 OUTPUT_Q,
                                 org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
                                 labelClass,
                                 VectorWritable.class);

  if (outputBBtProducts) {
    MultipleOutputs.addNamedOutput(oldApiJob,
                                   OUTPUT_BBT,
                                   org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
                                   IntWritable.class,
                                   VectorWritable.class);
  }

  /*
   * HACK: we use old api multiple outputs since they are not available in the
   * new api of either 0.20.2 or 0.20.203 but wrap it into a new api job so we
   * can use new api interfaces.
   */
  Job job = new Job(oldApiJob);
  job.setJobName("Bt-job");
  job.setJarByClass(BtJob.class);

  job.setInputFormatClass(SequenceFileInputFormat.class);
  job.setOutputFormatClass(SequenceFileOutputFormat.class);
  FileInputFormat.setInputPaths(job, inputPathA);
  if (minSplitSize > 0) {
    FileInputFormat.setMinInputSplitSize(job, minSplitSize);
  }
  FileOutputFormat.setOutputPath(job, outputPath);

  // WARN: tight hadoop integration here:
  job.getConfiguration().set("mapreduce.output.basename", OUTPUT_BT);

  FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
  SequenceFileOutputFormat.setOutputCompressionType(job,
                                                    CompressionType.BLOCK);

  job.setMapOutputKeyClass(LongWritable.class);
  job.setMapOutputValueClass(SparseRowBlockWritable.class);

  job.setOutputKeyClass(IntWritable.class);
  job.setOutputValueClass(VectorWritable.class);

  job.setMapperClass(BtMapper.class);
  job.setCombinerClass(OuterProductCombiner.class);
  job.setReducerClass(OuterProductReducer.class);

  job.getConfiguration().setInt(QJob.PROP_K, k);
  job.getConfiguration().setInt(QJob.PROP_P, p);
  job.getConfiguration().set(PROP_QJOB_PATH, inputPathQJob.toString());
  job.getConfiguration().setBoolean(PROP_OUPTUT_BBT_PRODUCTS,
                                    outputBBtProducts);
  job.getConfiguration().setInt(PROP_OUTER_PROD_BLOCK_HEIGHT, btBlockHeight);

  job.setNumReduceTasks(numReduceTasks);

  /*
   * we can broadcast Rhat files since all of them are required by each job,
   * but not Q files which correspond to splits of A (so each split of A will
   * require only particular Q file, each time different one).
   */
  if (broadcast) {
    job.getConfiguration().set(PROP_RHAT_BROADCAST, "y");

    /*
     * Resolve the file system from the QJob path itself rather than taking
     * the default one: the path may live on a file system other than the
     * configured default, in which case globbing it through the default
     * file system would be rejected.
     */
    FileSystem fs = FileSystem.get(inputPathQJob.toUri(), conf);
    FileStatus[] fstats =
      fs.globStatus(new Path(inputPathQJob, QJob.OUTPUT_RHAT + "-*"));
    if (fstats != null) {
      for (FileStatus fstat : fstats) {
        /*
         * new api is not enabled yet in our dependencies at this time, still
         * using deprecated one
         */
        DistributedCache.addCacheFile(fstat.getPath().toUri(),
                                      job.getConfiguration());
      }
    }
  }

  job.submit();
  job.waitForCompletion(false);

  if (!job.isSuccessful()) {
    throw new IOException("Bt job unsuccessful.");
  }
}
}