/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.utils;

import java.io.IOException;
import java.util.Random;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.lib.MultipleOutputs;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;

/**
 * A MapReduce version of SplitInput. This class takes a SequenceFile input,
 * e.g. a set of training data for a learning algorithm, downsamples it,
 * applies a random permutation, and splits it into test and training sets.
 */
@SuppressWarnings("deprecation")
public final class SplitInputJob {

  private static final String DOWNSAMPLING_FACTOR = "SplitInputJob.downsamplingFactor";
  private static final String RANDOM_SELECTION_PCT = "SplitInputJob.randomSelectionPct";
  private static final String TRAINING_TAG = "training";
  private static final String TEST_TAG = "test";

  private SplitInputJob() { }

  /**
   * Runs the job that downsamples, randomly permutes, and splits the data
   * into test and training sets. This job takes a SequenceFile as input and
   * outputs two SequenceFiles, test-r-00000 and training-r-00000, which
   * contain the test and training sets respectively.
   *
   * @param initialConf
   *          Hadoop configuration on which the job is based
   * @param inputPath
   *          path to input data SequenceFile
   * @param outputPath
   *          path for output data SequenceFiles
   * @param keepPct
   *          percentage of key-value pairs in the input to keep; the rest
   *          are discarded
   * @param randomSelectionPercent
   *          percentage of the kept key-value pairs to allocate to the test
   *          set; the remainder are allocated to the training set
   */
  @SuppressWarnings("rawtypes")
  public static void run(Configuration initialConf, Path inputPath,
      Path outputPath, int keepPct, float randomSelectionPercent)
      throws IOException, ClassNotFoundException, InterruptedException {

    int downsamplingFactor = (int) (100.0 / keepPct);
    initialConf.setInt(DOWNSAMPLING_FACTOR, downsamplingFactor);
    initialConf.setFloat(RANDOM_SELECTION_PCT, randomSelectionPercent);

    // Determine class of keys and values by peeking at the first record
    FileSystem fs = FileSystem.get(initialConf);
    SequenceFileDirIterator<? extends WritableComparable, Writable> iterator =
        new SequenceFileDirIterator<WritableComparable, Writable>(inputPath,
            PathType.LIST, PathFilters.partFilter(), null, false, fs.getConf());
    Class<? extends WritableComparable> keyClass;
    Class<? extends Writable> valueClass;
    if (iterator.hasNext()) {
      Pair<? extends WritableComparable, Writable> pair = iterator.next();
      keyClass = pair.getFirst().getClass();
      valueClass = pair.getSecond().getClass();
    } else {
      throw new IllegalStateException("Couldn't determine class of the input values");
    }

    // Use old API for multiple outputs
    JobConf oldApiJob = new JobConf(initialConf);
    MultipleOutputs.addNamedOutput(oldApiJob, TRAINING_TAG,
        org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
        keyClass, valueClass);
    MultipleOutputs.addNamedOutput(oldApiJob, TEST_TAG,
        org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
        keyClass, valueClass);

    // Set up job with new API
    Job job = new Job(oldApiJob);
    FileInputFormat.addInputPath(job, inputPath);
    FileOutputFormat.setOutputPath(job, outputPath);
    job.setNumReduceTasks(1);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(SplitInputMapper.class);
    job.setReducerClass(SplitInputReducer.class);
    job.setSortComparatorClass(SplitInputComparator.class);
    job.setOutputKeyClass(keyClass);
    job.setOutputValueClass(valueClass);
    job.submit();
    job.waitForCompletion(true);
  }
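
  /*
   * Worked example of the downsampling arithmetic in run() (an illustration,
   * not part of the job): with keepPct = 25, downsamplingFactor becomes
   * (int) (100.0 / 25) = 4, so the mapper below emits records 0, 4, 8, ...
   * and keeps roughly 25% of the input. For values of keepPct that do not
   * divide 100 evenly, the fraction actually kept is 1 / downsamplingFactor;
   * e.g. keepPct = 30 gives a factor of (int) 3.33 = 3 and keeps about 33%.
   */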

  /** Mapper which downsamples the input by downsamplingFactor */
  public static class SplitInputMapper extends
      Mapper<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private int downsamplingFactor;

    @Override
    public void setup(Context context) {
      downsamplingFactor = context.getConfiguration().getInt(DOWNSAMPLING_FACTOR, 1);
    }

    /** Only run map() for one out of every downsamplingFactor inputs */
    @Override
    public void run(Context context) throws IOException, InterruptedException {
      setup(context);
      for (int i = 0; context.nextKeyValue(); i++) {
        if (i % downsamplingFactor == 0) {
          map(context.getCurrentKey(), context.getCurrentValue(), context);
        }
      }
      cleanup(context);
    }
  }
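
  /*
   * Worked example of the random split performed by the reducer below (an
   * illustration, not part of the job): rnd.nextInt(100) draws uniformly
   * from 0..99, so with randomSelectionPercent = 20.0f the test
   * "rnd.nextInt(100) < randomSelectionPercent" succeeds for draws 0..19,
   * sending each pair to the test set with probability 0.20 and to the
   * training set with probability 0.80.
   */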

  /**
   * Reducer which uses MultipleOutputs to randomly allocate key-value pairs
   * between the test and training outputs
   */
  public static class SplitInputReducer extends
      Reducer<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private MultipleOutputs multipleOutputs;
    private OutputCollector<WritableComparable<?>, Writable> trainingCollector = null;
    private OutputCollector<WritableComparable<?>, Writable> testCollector = null;
    private final Random rnd = RandomUtils.getRandom();
    private float randomSelectionPercent;

    @SuppressWarnings("unchecked")
    @Override
    protected void setup(Context context) throws IOException {
      randomSelectionPercent = context.getConfiguration().getFloat(RANDOM_SELECTION_PCT, 0);
      multipleOutputs = new MultipleOutputs(new JobConf(context.getConfiguration()));
      trainingCollector = multipleOutputs.getCollector(TRAINING_TAG, null);
      testCollector = multipleOutputs.getCollector(TEST_TAG, null);
    }

    /**
     * Randomly allocate key-value pairs between the test and training sets.
     * randomSelectionPercent of the pairs will go to the test set.
     */
    @Override
    protected void reduce(WritableComparable<?> key, Iterable<Writable> values,
        Context context) throws IOException, InterruptedException {
      for (Writable value : values) {
        if (rnd.nextInt(100) < randomSelectionPercent) {
          testCollector.collect(key, value);
        } else {
          trainingCollector.collect(key, value);
        }
      }
    }

    @Override
    protected void cleanup(Context context) throws IOException {
      multipleOutputs.close();
    }
  }

  /** Randomly permute key-value pairs during the sort phase */
  public static class SplitInputComparator extends WritableComparator {

    private final Random rnd = RandomUtils.getRandom();

    protected SplitInputComparator() {
      super(WritableComparable.class);
    }

    /**
     * Deliberately violates the comparator contract by returning a random
     * result, so that the sort phase shuffles the records rather than
     * ordering them.
     */
    @Override
    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
      if (rnd.nextBoolean()) {
        return 1;
      } else {
        return -1;
      }
    }
  }

}
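
// A minimal, hypothetical driver sketching how SplitInputJob.run might be
// invoked; the class name, paths, and percentages below are illustrative
// assumptions, not part of SplitInputJob itself.
class SplitInputJobExample {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // Keep 50% of the input pairs, then send 20% of the kept pairs to the
    // test set and the remaining 80% to the training set.
    SplitInputJob.run(conf,
        new Path("/data/input"),   // assumed input SequenceFile directory
        new Path("/data/split"),   // assumed output directory
        50,
        20.0f);
  }
}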