package skywriting.examples.skyhout.linalg; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DoubleWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.serializer.Serialization; import org.apache.hadoop.io.serializer.WritableSerialization; import org.apache.mahout.math.VectorWritable; import skywriting.examples.skyhout.common.PartialHashOutputCollector; import skywriting.examples.skyhout.common.SkywritingTaskFileSystem; import uk.co.mrry.mercator.task.JarTaskLoader; import uk.co.mrry.mercator.task.Task; public class MatrixVectorMultiplyTask implements Task { @Override public void invoke(InputStream[] fis, OutputStream[] fos, String[] args) { try { Configuration conf = new Configuration(); conf.setClassLoader(JarTaskLoader.CLASSLOADER); conf.setClass("io.serializations", WritableSerialization.class, Serialization.class); SkywritingTaskFileSystem fs = new SkywritingTaskFileSystem(fis, fos, conf); // Input[0] is the matrix chunk; [1] is the vector. assert fs.numInputs() == 2; // Output[0] is the vector chunk. assert fs.numOutputs() == 1; // Read in the vector. SequenceFile.Reader vectorReader = new SequenceFile.Reader(fs, new Path("/in/1"), conf); Text dummyKey = new Text(); VectorWritable vector = new VectorWritable(); vectorReader.next(dummyKey, vector); vectorReader.close(); // Iterate over matrix chunk rows. SequenceFile.Reader matrixReader = new SequenceFile.Reader(fs, new Path("/in/0"), conf); IntWritable currentRowIndex = new IntWritable(); VectorWritable currentRow = new VectorWritable(); SequenceFile.Writer output = new SequenceFile.Writer(fs, conf, new Path("/out/0"), IntWritable.class, DoubleWritable.class); DoubleWritable dotProduct = new DoubleWritable(); while (true) { try { boolean hasMore = matrixReader.next(currentRowIndex, currentRow); if (!hasMore) break; } catch (EOFException eofe) { break; } dotProduct.set(currentRow.get().dot(vector.get())); output.append(currentRowIndex, dotProduct); } output.close(); } catch (IOException ioe) { throw new RuntimeException(ioe); } } }