package com.skp.experiment.cf.math.hadoop;

import java.io.IOException;
import java.util.Iterator;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Map-only Hadoop job that prunes sparse vectors: every non-zero element whose
 * value is below a configurable threshold is dropped; all remaining elements are
 * copied unchanged to the output vector.
 *
 * <p>Input and output are sequence files of {@code IntWritable -> VectorWritable}.
 *
 * <p>NOTE: "Threashold" in the class name is a historical typo kept as-is for
 * backward compatibility with existing callers.
 */
public class PruneVectorWithThreasholdJob extends AbstractJob {

  /** Configuration key holding the float pruning threshold (defaults to 0 = keep everything). */
  public static final String THRESHOLD = PruneVectorWithThreasholdJob.class.getName() + ".threshold";

  /**
   * Builds a ready-to-submit map-only pruning job.
   *
   * @param initConf base configuration copied into the job; callers should set
   *                 {@link #THRESHOLD} on it (or on the returned job's configuration)
   *                 before submitting, otherwise no pruning occurs
   * @param input    sequence file(s) of {@code IntWritable -> VectorWritable}
   * @param output   destination path for the pruned sequence file
   * @return the configured (not yet submitted) job
   * @throws IOException if the input path cannot be added
   */
  public static Job createPruneVectorWithThresholdJob(Configuration initConf, Path input, Path output)
      throws IOException {
    Job job = new Job(initConf, PruneVectorWithThreasholdJob.class.getName());
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(PruneVectorWithThresholdMapper.class);
    job.setMapOutputKeyClass(IntWritable.class);
    job.setMapOutputValueClass(VectorWritable.class);
    // BUGFIX: in a map-only job SequenceFileOutputFormat creates its writer from the
    // job *output* classes, which otherwise default to LongWritable/Text and make
    // every context.write(IntWritable, VectorWritable) fail with a wrong-class error.
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(VectorWritable.class);
    job.setJarByClass(PruneVectorWithThreasholdJob.class);
    FileInputFormat.addInputPath(job, input);
    FileOutputFormat.setOutputPath(job, output);
    job.setNumReduceTasks(0); // map-only: per-vector pruning needs no aggregation
    return job;
  }

  /**
   * CLI entry point. Accepts the standard Mahout input/output options plus an
   * optional {@code --threshold} (default "0", i.e. no pruning — preserving the
   * previous behavior of invocations that pass no threshold).
   *
   * @return 0 on success, -1 on argument-parse failure or job failure
   */
  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    // BUGFIX: the threshold used to be silently fixed at 0 (prune nothing) because
    // it was never parsed from the command line nor written into the configuration.
    addOption("threshold", "th", "prune vector elements whose value is below this threshold", "0");
    Map<String, String> parsedArg = parseArguments(args);
    if (parsedArg == null) {
      return -1;
    }
    Job job = prepareJob(getInputPath(), getOutputPath(),
                         SequenceFileInputFormat.class,
                         PruneVectorWithThresholdMapper.class,
                         IntWritable.class, VectorWritable.class,
                         SequenceFileOutputFormat.class);
    job.getConfiguration().setFloat(THRESHOLD, Float.parseFloat(getOption("threshold")));
    // BUGFIX: propagate job failure to the caller instead of always returning 0.
    return job.waitForCompletion(true) ? 0 : -1;
  }

  /**
   * Mapper that rebuilds each input vector keeping only the non-zero elements
   * whose value is {@code >=} the configured {@link #THRESHOLD}.
   */
  public static class PruneVectorWithThresholdMapper
      extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {

    // Instance (not static) field: a static mutable field shared across mapper
    // instances breaks under multithreaded map task execution.
    private float threshold;

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
      threshold = context.getConfiguration().getFloat(THRESHOLD, 0);
    }

    @Override
    protected void map(IntWritable key, VectorWritable value, Context context)
        throws IOException, InterruptedException {
      Vector colVectors = value.get();
      Iterator<Vector.Element> cols = colVectors.iterateNonZero();
      Vector result = new RandomAccessSparseVector(colVectors.size(), colVectors.getNumNondefaultElements());
      while (cols.hasNext()) {
        Vector.Element col = cols.next();
        if (col.get() < threshold) {
          continue; // drop sub-threshold element
        }
        result.set(col.index(), col.get());
      }
      context.write(key, new VectorWritable(result));
    }
  }
}