package com.skp.experiment.cf.als.hadoop;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.common.AbstractJob;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Filters a textual rating dataset before ALS training: drops ratings whose
 * item id is listed in an "invalid item" side file or starts with a
 * prohibited prefix, then drops all ratings of users who have fewer than a
 * configurable minimum number of remaining ratings.
 */
public class FilterInputJob extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(FilterInputJob.class);

  private static final String INVALID_ITEM_PATH = FilterInputJob.class.getName() + ".invalidItemPath";
  private static final String MIN_NON_DEFAULT_ELEMENTS_NUM =
      FilterInputJob.class.getName() + ".minNonDefaultElementsNum";
  private static final String[] PROHIBIT_PREFIX_ON_ITEM_IDS = {"R", "H"};
  private static final String DELIMITER = ",";
  private static final String ITEM_INDEX = FilterInputJob.class.getName() + ".itemIndex";
  private static final String USER_INDEX = FilterInputJob.class.getName() + ".userIndex";

  /*
  private static final String TRAINING_PERCENTAGE = FilterInputJob.class.getName() + ".trainingPercentage";
  private static final String PROBE_PERCENTAGE = FilterInputJob.class.getName() + ".probePercentage";
  private static final String TRAINING_PATH = FilterInputJob.class.getName() + ".trainingPath";
  private static final String PROBE_PATH = FilterInputJob.class.getName() + ".probePath";
  private static final double DEFAULT_TRAINING_PERCENTAGE = 0.9;
  private static final double DEFAULT_PROBE_PERCENTAGE = 0.1;
  */
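  /*
   * A minimal usage sketch. The jar name and HDFS paths are hypothetical;
   * --input and --output come from AbstractJob's addInputOption()/addOutputOption(),
   * the remaining flags from the addOption(...) calls in run() below.
   *
   *   hadoop jar experiment-cf.jar com.skp.experiment.cf.als.hadoop.FilterInputJob \
   *     --input /data/ratings \
   *     --output /data/ratings-filtered \
   *     --invalidItemPath /data/invalid-items \
   *     --minNonDefaultElementsNum 5 \
   *     --itemIndex 1 \
   *     --userIndex 0
   */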
  public static void main(String[] args) throws Exception {
    ToolRunner.run(new FilterInputJob(), args);
  }

  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    addOption("invalidItemPath", null, "invalid item path.", null);
    addOption("minNonDefaultElementsNum", "minK",
        "minimum number of non-default elements in the result vector.", String.valueOf(0));
    addOption("itemIndex", "itemIdx", "item id index", String.valueOf(1));
    addOption("userIndex", "userIdx", "user id index", String.valueOf(0));
    /*
    addOption("trainingPercentage", "t", "percentage of the data to use as training set",
        String.valueOf(DEFAULT_TRAINING_PERCENTAGE));
    addOption("probePercentage", "p", "percentage of the data to use as probe set",
        String.valueOf(DEFAULT_PROBE_PERCENTAGE));
    */

    Map<String, String> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }

    Job filterJob = prepareJob(getInputPath(), getOutputPath(), TextInputFormat.class,
        FilterInvalidItemMapper.class, Text.class, Text.class,
        FilterInvalidItemReducer.class, NullWritable.class, Text.class,
        TextOutputFormat.class);
    //Path trainingPath = new Path(getOutputPath().getParent(), "trainingSet");
    //Path probePath = new Path(getOutputPath().getParent(), "probeSet");
    // The option defaults to null and Configuration.set() rejects null values,
    // so only pass the path on to the mapper when it was actually given.
    if (getOption("invalidItemPath") != null) {
      filterJob.getConfiguration().set(INVALID_ITEM_PATH, getOption("invalidItemPath"));
    }
    filterJob.getConfiguration().set(MIN_NON_DEFAULT_ELEMENTS_NUM, getOption("minNonDefaultElementsNum"));
    filterJob.getConfiguration().setInt(ITEM_INDEX, Integer.parseInt(getOption("itemIndex")));
    filterJob.getConfiguration().setInt(USER_INDEX, Integer.parseInt(getOption("userIndex")));
    /*
    filterJob.getConfiguration().set(TRAINING_PATH, trainingPath.toString());
    filterJob.getConfiguration().set(PROBE_PATH, probePath.toString());
    filterJob.getConfiguration().setFloat(TRAINING_PERCENTAGE, Float.parseFloat(getOption("trainingPercentage")));
    filterJob.getConfiguration().setFloat(PROBE_PERCENTAGE, Float.parseFloat(getOption("probePercentage")));
    */
    boolean succeeded = filterJob.waitForCompletion(true);
    return succeeded ? 0 : -1;
  }

  private static class FilterInvalidItemMapper extends Mapper<LongWritable, Text, Text, Text> {

    // Per-task state; the Text objects are reused across map() calls to avoid allocation.
    private Map<String, String> invalidItems = null;
    private final Text outKey = new Text();
    private final Text outValue = new Text();
    private int itemIndex = 1;
    private int userIndex = 0;

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
      Configuration conf = context.getConfiguration();
      itemIndex = conf.getInt(ITEM_INDEX, 1);
      userIndex = conf.getInt(USER_INDEX, 0);
      if (conf.get(INVALID_ITEM_PATH) != null) {
        invalidItems = ALSMatrixUtil.fetchTextFiles(context, new Path(conf.get(INVALID_ITEM_PATH)),
            DELIMITER, Arrays.asList(itemIndex), Arrays.asList(itemIndex));
        context.setStatus("total: " + invalidItems.size());
      }
    }

    @Override
    protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
      String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
      String userID = tokens[userIndex];
      String itemID = tokens[itemIndex];
      // Keep the rating only if the item is neither listed as invalid
      // nor carries a prohibited id prefix.
      if (invalidItems == null || !invalidItems.containsKey(itemID)) {
        boolean prohibited = false;
        for (String prefix : PROHIBIT_PREFIX_ON_ITEM_IDS) {
          if (itemID.startsWith(prefix)) {
            log.info("Prohibited ItemID: {}", itemID);
            prohibited = true;
            break;
          }
        }
        if (!prohibited) {
          outKey.set(userID);
          outValue.set(line);
          ctx.write(outKey, outValue);
        }
      }
    }
  }
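  /*
   * The mapper emits (userID, original input line), so the shuffle groups each
   * user's surviving ratings together and the reducer below can enforce the
   * per-user minimum count. ALSMatrixUtil.fetchTextFiles is assumed here to
   * read the delimiter-split side files under the given path into a map keyed
   * by the values of the selected column; its exact contract is project
   * internal and not shown in this file.
   */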
  private static class FilterInvalidItemReducer extends Reducer<Text, Text, NullWritable, Text> {

    private int minNonDefaultElementsNum = 0;
    private final Text outValue = new Text();

    /*
    private Random random;
    private float trainingBound;
    private float probeBound;
    private Path trainingPath;
    private Path probePath;
    //private SequenceFile.Writer trainingWriter = null;
    //private SequenceFile.Writer probeWriter = null;
    private FSDataOutputStream trainingOut = null;
    private FSDataOutputStream probeOut = null;

    private String getPartNum(Context context) {
      String taskId = context.getConfiguration().get("mapred.task.id");
      String[] parts = taskId.split("_");
      return "part-" + parts[parts.length - 2] + "-" + parts[parts.length - 1];
    }
    */

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
      Configuration conf = context.getConfiguration();
      minNonDefaultElementsNum = conf.getInt(MIN_NON_DEFAULT_ELEMENTS_NUM, 0);
      /*
      FileSystem fs = FileSystem.get(conf);
      random = RandomUtils.getRandom();
      trainingBound = conf.getFloat(TRAINING_PERCENTAGE, (float) DEFAULT_TRAINING_PERCENTAGE);
      probeBound = trainingBound + conf.getFloat(PROBE_PERCENTAGE, (float) DEFAULT_PROBE_PERCENTAGE);
      trainingPath = new Path(conf.get(TRAINING_PATH), getPartNum(context));
      probePath = new Path(conf.get(PROBE_PATH), getPartNum(context));
      trainingOut = fs.create(trainingPath);
      probeOut = fs.create(probePath);
      */
    }

    @Override
    protected void reduce(Text user, Iterable<Text> lines, Context context)
        throws IOException, InterruptedException {
      // Materialize this user's ratings so they can be counted before emitting.
      List<String> aggregated = new ArrayList<String>();
      for (Text line : lines) {
        aggregated.add(line.toString());
      }
      // Drop users with fewer than the minimum number of ratings.
      if (aggregated.size() < minNonDefaultElementsNum) {
        return;
      }
      for (String s : aggregated) {
        outValue.set(s);
        /*
        double randomValue = random.nextDouble();
        if (randomValue <= trainingBound) {
          trainingOut.writeUTF(s);
          //trainingWriter.append(NullWritable.get(), outValue);
        } else if (randomValue <= probeBound) {
          probeOut.writeUTF(s);
          //probeWriter.append(NullWritable.get(), outValue);
        }
        */
        context.write(NullWritable.get(), outValue);
      }
    }
  }
}