package hip.ch6.sampler;
import com.twitter.elephantbird.util.HadoopCompat;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.mapreduce.*;
import org.apache.hadoop.util.ReflectionUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * An {@link InputFormat} decorator that applies reservoir sampling to the
 * records produced by a wrapped input format. Each map task reads up to a
 * configured maximum number of records from its split and retains a uniform
 * random sample of them; only the sampled records are fed to the mapper.
 *
 * <p>Configure with the static helpers, starting with
 * {@link #setInputFormat(Job, Class)} to name the wrapped input format.
 *
 * @param <K> key type produced by the wrapped input format
 * @param <V> value type produced by the wrapped input format
 */
public class ReservoirSamplerInputFormat<K extends Writable, V extends Writable>
    extends InputFormat<K, V> {

  /** Configuration key: class name of the wrapped input format. */
  public static final String INPUT_FORMAT_CLASS =
      "reservoir.inputformat.class";

  /** Configuration key: number of samples to keep (job-wide or per-split). */
  public static final String SAMPLES_NUMBER =
      "reservoir.samples.number";

  /**
   * Configuration key: when true, {@link #SAMPLES_NUMBER} is interpreted as a
   * per-input-split count rather than a job-wide total divided among map tasks.
   */
  public static final String USE_SAMPLES_NUMBER_PER_INPUT_SPLIT =
      "reservoir.samples.useperinputsplit";

  /** Configuration key: cap on records read from each split while sampling. */
  public static final String MAXRECORDS_READ =
      "reservoir.samples.maxrecordsread";

  public static final int DEFAULT_NUM_SAMPLES = 1000;
  public static final boolean DEFAULT_USE_SAMPLES_PER_INPUT_SPLIT =
      false;
  public static final int DEFAULT_MAX_RECORDS_READ = 100000;

  /** Lazily-instantiated wrapped input format; see {@link #getInputFormat}. */
  private InputFormat<K, V> inputFormat;

  /**
   * A helper function to configure the actual InputFormat for the job, and
   * to install this sampler as the job's input format.
   *
   * @param job the job to configure
   * @param inputFormat the input format which will be wrapped by
   *     the sampler
   */
  public static void setInputFormat(Job job,
      Class<? extends InputFormat> inputFormat) {
    job.getConfiguration().setClass(INPUT_FORMAT_CLASS, inputFormat,
        InputFormat.class);
    job.setInputFormatClass(ReservoirSamplerInputFormat.class);
  }

  /**
   * Sets the number of samples to retain.
   *
   * @param job the job to configure
   * @param numSamples sample count (job-wide total unless the per-split
   *     flag is set via {@link #setUseSamplesNumberPerInputSplit})
   */
  public static void setNumSamples(Job job,
      int numSamples) {
    job.getConfiguration().setInt(SAMPLES_NUMBER, numSamples);
  }

  /**
   * Sets the maximum number of records each task reads while sampling.
   *
   * @param job the job to configure
   * @param maxRecords cap on records consumed from each split
   */
  public static void setMaxRecordsToRead(Job job,
      int maxRecords) {
    job.getConfiguration().setInt(MAXRECORDS_READ, maxRecords);
  }

  /**
   * Chooses whether the sample count applies per input split (true) or is a
   * job-wide total divided among map tasks (false, the default).
   */
  public static void setUseSamplesNumberPerInputSplit(Job job,
      boolean usePerInputSplit) {
    job.getConfiguration().setBoolean(
        USE_SAMPLES_NUMBER_PER_INPUT_SPLIT, usePerInputSplit);
  }

  /**
   * Returns the number of samples each map task should retain. When the
   * per-split flag is unset, the configured total is divided evenly
   * (rounded up) among the map tasks.
   */
  public static int getNumSamples(Configuration conf) {
    int numSamples = conf.getInt(SAMPLES_NUMBER, DEFAULT_NUM_SAMPLES);
    boolean usePerSample =
        conf.getBoolean(USE_SAMPLES_NUMBER_PER_INPUT_SPLIT,
            DEFAULT_USE_SAMPLES_PER_INPUT_SPLIT);
    if (usePerSample) {
      return numSamples;
    }
    int numMapTasks = conf.getInt("mapred.map.tasks", 1);
    // Divide in floating point: the original integer division truncated
    // before Math.ceil ran, making the ceil a no-op and under-allocating
    // samples per task (e.g. 1000 samples / 3 tasks gave 333, not 334).
    return (int) Math.ceil((double) numSamples / numMapTasks);
  }

  /** Returns the per-task cap on records read while sampling. */
  public static int getMaxRecordsToRead(Configuration conf) {
    return conf.getInt(MAXRECORDS_READ, DEFAULT_MAX_RECORDS_READ);
  }

  /**
   * Lazily instantiates the wrapped input format named by
   * {@link #INPUT_FORMAT_CLASS}.
   *
   * @throws IOException if the job was not configured with a wrapped format
   */
  @SuppressWarnings("unchecked")
  public InputFormat<K, V> getInputFormat(Configuration conf)
      throws IOException {
    if (inputFormat == null) {
      Class<?> ifClass = conf.getClass(INPUT_FORMAT_CLASS, null);
      if (ifClass == null) {
        throw new IOException("Job must be configured with " +
            INPUT_FORMAT_CLASS);
      }
      inputFormat = (InputFormat<K, V>) ReflectionUtils
          .newInstance(ifClass, conf);
    }
    return inputFormat;
  }

  /** Delegates split calculation to the wrapped input format. */
  @Override
  public List<InputSplit> getSplits(JobContext context)
      throws IOException, InterruptedException {
    return getInputFormat(HadoopCompat.getConfiguration(context))
        .getSplits(context);
  }

  /**
   * Creates a reader that eagerly samples the split's records during
   * {@code initialize} and then replays the sample to the caller.
   */
  @Override
  public RecordReader<K, V> createRecordReader(InputSplit split,
      TaskAttemptContext context)
      throws IOException, InterruptedException {
    Configuration conf = HadoopCompat.getConfiguration(context);
    return new ReservoirSamplerRecordReader<K, V>(context,
        getInputFormat(conf).createRecordReader(split, context),
        getNumSamples(conf),
        getMaxRecordsToRead(conf));
  }

  /**
   * A record reader that materializes a uniform random sample of the wrapped
   * reader's records up front (in {@link #initialize}) and then iterates over
   * the in-memory sample.
   */
  public static class ReservoirSamplerRecordReader<K extends Writable, V extends Writable>
      extends RecordReader<K, V> {
    private final Configuration conf;
    /** The wrapped reader whose records are sampled. */
    private final RecordReader<K, V> rr;
    /** Reservoir capacity. */
    private final int numSamples;
    /** Cap on records consumed from the wrapped reader. */
    private final int maxRecords;
    private final ArrayList<K> keys;
    private final ArrayList<V> values;
    /** Cursor into the sampled records; points one past the current record. */
    private int idx = 0;

    public ReservoirSamplerRecordReader(TaskAttemptContext context,
        RecordReader<K, V> rr,
        int numSamples,
        int maxRecords) {
      this.conf = HadoopCompat.getConfiguration(context);
      this.rr = rr;
      this.numSamples = numSamples;
      this.maxRecords = maxRecords;
      keys = new ArrayList<K>(numSamples);
      values = new ArrayList<V>(numSamples);
    }

    /**
     * Reads up to {@code maxRecords} records from the wrapped reader and
     * keeps a uniform sample of at most {@code numSamples} of them using
     * Vitter's Algorithm R.
     */
    @Override
    public void initialize(InputSplit split,
        TaskAttemptContext context)
        throws IOException, InterruptedException {
      rr.initialize(split, context);
      Random rand = new Random();
      for (int i = 0; i < maxRecords; i++) {
        if (!rr.nextKeyValue()) {
          break;
        }
        K key = rr.getCurrentKey();
        V val = rr.getCurrentValue();
        if (keys.size() < numSamples) {
          // Reservoir not yet full: keep every record. Clone because the
          // wrapped reader may reuse its key/value instances.
          keys.add(WritableUtils.clone(key, conf));
          values.add(WritableUtils.clone(val, conf));
        } else {
          // Algorithm R: record i (0-based) replaces a random reservoir
          // slot with probability numSamples / (i + 1). nextInt(i + 1)
          // draws uniformly from [0, i]; the original nextInt(i) excluded
          // i and biased the sample.
          int r = rand.nextInt(i + 1);
          if (r < numSamples) {
            keys.set(r, WritableUtils.clone(key, conf));
            values.set(r, WritableUtils.clone(val, conf));
          }
        }
      }
    }

    @Override
    public boolean nextKeyValue()
        throws IOException, InterruptedException {
      return idx++ < keys.size();
    }

    @Override
    public K getCurrentKey()
        throws IOException, InterruptedException {
      // idx was advanced by nextKeyValue, so the current record is idx - 1.
      return keys.get(idx - 1);
    }

    @Override
    public V getCurrentValue()
        throws IOException, InterruptedException {
      return values.get(idx - 1);
    }

    @Override
    public float getProgress()
        throws IOException, InterruptedException {
      if (keys.isEmpty()) {
        // No sampled records: nothing left to read.
        return 1.0f;
      }
      // Divide in float: the original integer division could only yield
      // 0 or 1 (and divided by zero on an empty sample).
      return Math.min(idx, keys.size()) / (float) keys.size();
    }

    @Override
    public void close() throws IOException {
      rr.close();
    }
  }
}