/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.ga.watchmaker.cd.hadoop; import java.io.IOException; import java.util.Random; import com.google.common.base.Preconditions; import com.google.common.io.Closeables; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.common.RandomWrapper; /** * Separate the input data into a training and testing set. */ public final class DatasetSplit { private static final String SEED = "traintest.seed"; private static final String THRESHOLD = "traintest.threshold"; private static final String TRAINING = "traintest.training"; private final long seed; private final double threshold; private boolean training; /** * * @param seed * @param threshold * fraction of the total dataset that will be used for training */ public DatasetSplit(long seed, double threshold) { this.seed = seed; this.threshold = threshold; this.training = true; } public DatasetSplit(double threshold) { this(((RandomWrapper) RandomUtils.getRandom()).getSeed(), threshold); } public DatasetSplit(Configuration conf) { seed = getSeed(conf); threshold = getThreshold(conf); training = isTraining(conf); } public long getSeed() { return seed; } public double getThreshold() { return threshold; } public boolean isTraining() { return training; } public void setTraining(boolean training) { this.training = training; } public void storeJobParameters(Configuration conf) { conf.set(SEED, String.valueOf(seed)); conf.set(THRESHOLD, Double.toString(threshold)); conf.setBoolean(TRAINING, training); } static long getSeed(Configuration conf) { String seedstr = conf.get(SEED); Preconditions.checkArgument(seedstr != null, "Job parameter %s not found", SEED); return Long.parseLong(seedstr); } static double getThreshold(Configuration conf) { String thrstr = conf.get(THRESHOLD); Preconditions.checkArgument(thrstr != null, "Job parameter %s not found", THRESHOLD); return Double.parseDouble(thrstr); } static boolean isTraining(Configuration conf) { Preconditions.checkArgument(conf.get(TRAINING) != null, "Job parameter %s not found", TRAINING); return conf.getBoolean(TRAINING, true); } /** * a {@link RecordReader} that skips some lines from the * input. Uses a Random number generator with a specific seed to decide if a line will be skipped or not. */ public static class RndLineRecordReader extends RecordReader<LongWritable, Text> { private final RecordReader<LongWritable, Text> reader; private final Random rng; private final double threshold; private final boolean training; private final LongWritable k = new LongWritable(); private final Text v = new Text(); public RndLineRecordReader(RecordReader<LongWritable, Text> reader, Configuration conf) { Preconditions.checkArgument(reader != null, "Null reader"); this.reader = reader; DatasetSplit split = new DatasetSplit(conf); rng = RandomUtils.getRandom(split.getSeed()); threshold = split.getThreshold(); training = split.isTraining(); } @Override public void close() throws IOException { Closeables.closeQuietly(reader); } @Override public float getProgress() throws IOException { try { return reader.getProgress(); } catch (InterruptedException e) { return 0.0f; } } @Override public LongWritable getCurrentKey() throws IOException, InterruptedException { return k; } @Override public Text getCurrentValue() throws IOException, InterruptedException { return v; } @Override public void initialize(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException { reader.initialize(split, context); } @Override public boolean nextKeyValue() throws IOException, InterruptedException { boolean read; do { read = reader.nextKeyValue(); } while (read && !selected()); if (!read) { return false; } k.set(reader.getCurrentKey().get()); v.set(reader.getCurrentValue()); return true; } /** * * @return true if the current input line is not skipped. */ private boolean selected() { return training ? rng.nextDouble() < threshold : rng.nextDouble() >= threshold; } } /** * {@link TextInputFormat} that uses a {@link RndLineRecordReader} as a RecordReader */ public static class DatasetTextInputFormat extends TextInputFormat { @Override public RecordReader<LongWritable, Text> createRecordReader(InputSplit split, TaskAttemptContext context) { return new RndLineRecordReader(super.createRecordReader(split, context), context.getConfiguration()); } } }