package com.skp.experiment.common;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mrunit.mapreduce.MapDriver;
import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver;
import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
import org.apache.hadoop.mrunit.types.Pair;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;
import com.skp.experiment.common.Text2DistributedRowMatrixJob.Text2DistributedRowMatrixMapper;
import com.skp.experiment.common.Text2DistributedRowMatrixJob.Text2DistributedRowMatrixReducer;
/**
 * MRUnit tests for the mapper and reducer of {@link Text2DistributedRowMatrixJob}.
 *
 * <p>The mapper tests feed a single comma-separated line ({@code "2,3,4.3"}) under two different
 * row/column/value index configurations and check the emitted key and vector. The reducer test
 * feeds two partial sparse vectors under one key and checks the merged result.
 */
public class Text2DistributedRowMatrixJobTest extends MahoutTestCase {
    Text2DistributedRowMatrixMapper mapper = new Text2DistributedRowMatrixMapper();
    Text2DistributedRowMatrixReducer reducer = new Text2DistributedRowMatrixReducer();
    @SuppressWarnings("rawtypes")
    MapDriver<LongWritable, Text, WritableComparable, VectorWritable> mapDriver;
    @SuppressWarnings("rawtypes")
    ReduceDriver<WritableComparable, VectorWritable, WritableComparable, VectorWritable> reduceDriver;
    // NOTE(review): this driver is created in setUp() but never given a mapper/reducer and no
    // test uses it. Its IntWritable intermediate key type also presumably does not match the
    // mapper's raw WritableComparable output key. Kept only so the field remains available.
    MapReduceDriver<LongWritable, Text, IntWritable, VectorWritable, IntWritable, VectorWritable> mapReduceDriver;

    /** Creates fresh MRUnit drivers and attaches the mapper/reducer under test. */
    @SuppressWarnings("rawtypes")
    @Before
    public void setUp() throws Exception {
        super.setUp();
        mapDriver = new MapDriver<LongWritable, Text, WritableComparable, VectorWritable>();
        reduceDriver = new ReduceDriver<WritableComparable, VectorWritable, WritableComparable, VectorWritable>();
        mapReduceDriver = new MapReduceDriver<LongWritable, Text, IntWritable, VectorWritable, IntWritable, VectorWritable>();
        mapDriver.setMapper(mapper);
        reduceDriver.setReducer(reducer);
    }

    /**
     * Builds a mapper configuration from the job's option keys.
     *
     * @param rowIdx     value for {@code ROW_IDX_KEY}
     * @param colIdx     value for {@code COL_IDX_KEY}
     * @param valueIdx   value for {@code VALUE_IDX_KEY}
     * @param sequential value for {@code SEQUENTIAL}
     * @param outKeyType value for {@code OUT_KEY_TYPE} (the tests use "int" and "text")
     * @return a new Configuration holding exactly those five settings
     */
    private static Configuration mapperConf(int rowIdx, int colIdx, int valueIdx,
            boolean sequential, String outKeyType) {
        Configuration conf = new Configuration();
        conf.setInt(Text2DistributedRowMatrixJob.ROW_IDX_KEY, rowIdx);
        conf.setInt(Text2DistributedRowMatrixJob.COL_IDX_KEY, colIdx);
        conf.setInt(Text2DistributedRowMatrixJob.VALUE_IDX_KEY, valueIdx);
        conf.setBoolean(Text2DistributedRowMatrixJob.SEQUENTIAL, sequential);
        conf.set(Text2DistributedRowMatrixJob.OUT_KEY_TYPE, outKeyType);
        return conf;
    }

    /**
     * Row index in field 0 and column index in field 1, with int keys and a non-sequential
     * vector.
     */
    @SuppressWarnings("rawtypes")
    @Test
    public void testMapper() throws IOException {
        mapDriver.setConfiguration(mapperConf(0, 1, 2, false, "int"));
        mapDriver.withInput(new LongWritable(), new Text("2,3,4.3"));
        List<Pair<WritableComparable, VectorWritable>> outputs = mapDriver.run();

        // One input line should yield exactly one row vector; checking this first turns a
        // missing output into an assertion failure rather than an IndexOutOfBoundsException.
        assertEquals(1, outputs.size());
        WritableComparable key = outputs.get(0).getFirst();
        Vector row = outputs.get(0).getSecond().get();

        assertEquals(IntWritable.class, key.getClass());
        assertEquals(2, ((IntWritable) key).get());
        assertFalse(row.isSequentialAccess());
        assertEquals(4.3, row.get(3), EPSILON);
    }

    /**
     * Row and column fields swapped, with Text keys and a sequential-access vector requested.
     */
    @SuppressWarnings("rawtypes")
    @Test
    public void testMapperConfig() throws IOException {
        mapDriver.setConfiguration(mapperConf(1, 0, 2, true, "text"));
        mapDriver.withInput(new LongWritable(), new Text("2,3,4.3"));
        List<Pair<WritableComparable, VectorWritable>> outputs = mapDriver.run();

        assertEquals(1, outputs.size());
        WritableComparable key = outputs.get(0).getFirst();
        Vector row = outputs.get(0).getSecond().get();

        // Field 1 ("3") is now the row key and field 0 (2) the vector index.
        assertEquals(Text.class, key.getClass());
        assertEquals("3", key.toString());
        assertEquals(true, row.isSequentialAccess());
        assertEquals(4.3, row.get(2), EPSILON);
    }

    /** Two partial vectors for the same key should be merged into one row holding all entries. */
    @SuppressWarnings("rawtypes")
    @Test
    public void testReducer() throws IOException {
        Vector v1 = new RandomAccessSparseVector(Integer.MAX_VALUE);
        v1.set(0, 1.0);
        v1.set(1, 1.0);
        Vector v2 = new RandomAccessSparseVector(Integer.MAX_VALUE);
        v2.set(2, 1.0);
        v2.set(3, 1.0);
        List<VectorWritable> vectors = new ArrayList<VectorWritable>();
        vectors.add(new VectorWritable(v1));
        vectors.add(new VectorWritable(v2));

        reduceDriver.withInput(new IntWritable(1), vectors);
        List<Pair<WritableComparable, VectorWritable>> outputs = reduceDriver.run();

        assertEquals(1, outputs.size());
        WritableComparable key = outputs.get(0).getFirst();
        assertEquals(IntWritable.class, key.getClass());
        assertEquals(1, ((IntWritable) key).get());

        Vector merged = outputs.get(0).getSecond().get();
        // Exactly four non-default entries, and each input index 0..3 holds 1.0. Together these
        // pin both the positions and the values; iterating non-zeros alone would not catch
        // entries landing at the wrong index.
        assertEquals(4, merged.getNumNondefaultElements());
        for (int index = 0; index < 4; index++) {
            assertEquals(1.0, merged.get(index), EPSILON);
        }
    }
}