/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.skp.experiment.cf.als.hadoop;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;

import com.skp.experiment.common.HadoopClusterUtil;
import com.skp.experiment.common.parameter.DefaultOptionCreator;

/**
 * <p>Splits a recommendation dataset into k training/probe set pairs for
 * k-fold cross validation. Each user's ratings are randomly shuffled and
 * partitioned into k folds; for fold i, the i-th partition becomes the probe
 * set and the remaining k-1 partitions become the training set.</p>
 *
 * <p>Command line arguments specific to this class are:</p>
 *
 * <ol>
 * <li>--input (path): directory containing one or more text files with the dataset</li>
 * <li>--output (path): path where the training and probe sets should go</li>
 * <li>--kfold (int): number of folds for cross validation (optional, default 4)</li>
 * <li>--keyIndex (int): index of the token to group records by, e.g. 0 for the user ID (optional, default 0)</li>
 * </ol>
 */
public class KFoldDatasetSplitter extends AbstractJob {

  private static final String KEY_INDEX = KFoldDatasetSplitter.class.getName() + ".keyIndex";
  private static final String K_FOLD = KFoldDatasetSplitter.class.getName() + ".kFold";
  private static final String PROBE_SET = KFoldDatasetSplitter.class.getName() + ".probeSet";
  private static final String TRAIN_SET = KFoldDatasetSplitter.class.getName() + ".trainSet";

  private static final int DEFAULT_K_FOLD = 4;

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new KFoldDatasetSplitter(), args);
  }

  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    addOption("kfold", "k", "number of folds for cross validation.", String.valueOf(DEFAULT_K_FOLD));
    // Register the keyIndex option; the original code read it via getOption("keyIndex")
    // without registering it, which would fail at parse time.
    addOption("keyIndex", "ki", "index of the token to group records by.", String.valueOf(0));

    Map<String, String> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }

    Path markedPrefs = new Path(getOption("tempDir"), "markedPreferences");
    Path trainingSetPath = new Path(getOutputPath(), "trainingSet");
    Path probeSetPath = new Path(getOutputPath(), "probeSet");
    int kFold = Integer.parseInt(getOption("kfold"));
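    /*
     * Resulting directory layout (written as side-effect files by the reducer below):
     *   <output>/trainingSet/<fold>/<taskAttemptId>
     *   <output>/probeSet/<fold>/<taskAttemptId>
     * where <fold> runs from 0 to kfold-1. For each fold i, the probe set holds one
     * of the k partitions of every user's ratings and the training set holds the
     * other k-1.
     */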
    /*
     * Step 0: in a single pass over the input, build the (training, probe) set
     * pair for every one of the k folds, i.e. roughly ((k-1)/k, 1/k) of each
     * user's ratings.
     */
    Job markPreferences = prepareJob(getInputPath(), markedPrefs, TextInputFormat.class,
        MarkPreferencesMapper.class, Text.class, Text.class,
        MarkPreferencesReducer.class, NullWritable.class, Text.class,
        TextOutputFormat.class);
    markPreferences.getConfiguration().setInt(KEY_INDEX, Integer.parseInt(getOption("keyIndex")));
    markPreferences.getConfiguration().setInt(K_FOLD, kFold);
    markPreferences.getConfiguration().set(TRAIN_SET, trainingSetPath.toString());
    markPreferences.getConfiguration().set(PROBE_SET, probeSetPath.toString());
    markPreferences.waitForCompletion(true);

    return 0;
  }

  static class MarkPreferencesMapper extends Mapper<LongWritable, Text, Text, Text> {

    private final Text outKey = new Text();
    private int keyIndex = 0;

    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {
      keyIndex = ctx.getConfiguration().getInt(KEY_INDEX, 0);
    }

    /**
     * Emits each input line keyed by its keyIndex-th token (typically the user ID),
     * so that all ratings of one user arrive at the same reducer.
     */
    @Override
    protected void map(LongWritable key, Text text, Context ctx) throws IOException, InterruptedException {
      outKey.set(TasteHadoopUtils.splitPrefTokens(text.toString())[keyIndex]);
      ctx.write(outKey, text);
    }
  }

  static class MarkPreferencesReducer extends Reducer<Text, Text, NullWritable, Text> {

    private int kfold;
    private final List<FSDataOutputStream> trainStreams = new ArrayList<FSDataOutputStream>();
    private final List<FSDataOutputStream> probeStreams = new ArrayList<FSDataOutputStream>();
    private FileSystem fs;

    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {
      kfold = ctx.getConfiguration().getInt(K_FOLD, DEFAULT_K_FOLD);
      fs = FileSystem.get(ctx.getConfiguration());
      String taskId = HadoopClusterUtil.getAttemptId(ctx.getConfiguration());
      /* Open kfold streams per set; the task attempt ID in the file name
         ensures there is no collision between reducers. */
      for (int i = 0; i < kfold; i++) {
        Path curProbeSetPath = new Path(ctx.getConfiguration().get(PROBE_SET), i + "/" + taskId);
        Path curTrainSetPath = new Path(ctx.getConfiguration().get(TRAIN_SET), i + "/" + taskId);
        trainStreams.add(fs.create(curTrainSetPath, true));
        probeStreams.add(fs.create(curProbeSetPath, true));
      }
    }

    /**
     * Closes all open output streams.
     */
    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
      for (FSDataOutputStream s : trainStreams) {
        IOUtils.closeStream(s);
      }
      for (FSDataOutputStream s : probeStreams) {
        IOUtils.closeStream(s);
      }
    }

    /**
     * Receives the list of item:rating entries of one user, shuffles it, and
     * writes the k (training, probe) splits to their side-effect files.
     */
    @Override
    protected void reduce(Text key, Iterable<Text> values, Context ctx) throws IOException, InterruptedException {
      ArrayList<String> list = new ArrayList<String>();
      for (Text value : values) {
        list.add(value.toString());
      }
      // Random shuffle; space complexity is O(|# items for this user| x 2).
      KFoldCrossValidationUtils.randomSuffleInPlace(list);
      for (int i = 0; i < kfold; i++) {
        /* use the current fold as the probe set, the rest as the training set */
        Pair<List<String>, List<String>> trainingAndProbe = KFoldCrossValidationUtils.splitNth(list, kfold, i);
        List<String> trains = trainingAndProbe.getFirst();
        List<String> probes = trainingAndProbe.getSecond();
        /* flush the i-th fold's training/probe pair out to its file streams */
        for (String train : trains) {
          trainStreams.get(i).writeBytes(train + DefaultOptionCreator.NEWLINE);
        }
        for (String probe : probes) {
          probeStreams.get(i).writeBytes(probe + DefaultOptionCreator.NEWLINE);
        }
      }
    }
  }
}
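/*
 * For reference: KFoldCrossValidationUtils is an external helper that is not
 * part of this file. Below is a minimal sketch of what its two methods are
 * assumed to do, inferred only from the call sites above; the actual
 * implementation may differ.
 *
 *   // Assumed: uniform in-place shuffle, e.g. via java.util.Collections.
 *   public static <T> void randomSuffleInPlace(List<T> list) {
 *     Collections.shuffle(list);
 *   }
 *
 *   // Assumed: partitions list into k strided subsets and returns
 *   // (training, probe), where probe is the n-th subset and training
 *   // is everything else.
 *   public static <T> Pair<List<T>, List<T>> splitNth(List<T> list, int k, int n) {
 *     List<T> train = new ArrayList<T>();
 *     List<T> probe = new ArrayList<T>();
 *     for (int j = 0; j < list.size(); j++) {
 *       (j % k == n ? probe : train).add(list.get(j));
 *     }
 *     return new Pair<List<T>, List<T>>(train, probe);
 *   }
 */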