/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.df.mapreduce.inmem;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.mahout.classifier.df.mapreduce.Builder;
import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Random;
/**
 * Custom InputFormat that generates InputSplits given the desired number of trees.<br>
 * Each input split describes a contiguous range of tree ids, i.e. a subset of the trees.<br>
 * The number of splits is equal to the number of requested splits.
 */
public class InMemInputFormat extends InputFormat<IntWritable,NullWritable> {

  private static final Logger log = LoggerFactory.getLogger(InMemInputFormat.class);

  /** Source of per-split seeds; stays {@code null} when no seed is configured or in single-seed mode. */
  private Random rng;

  /** Seed read from the configuration by the last call to getSplits(), or {@code null} if none. */
  private Long seed;

  /** Value of the single-seed debug flag read by the last call to getSplits(). */
  private boolean isSingleSeed;

  /**
   * Used for DEBUG purposes only. If true and a seed is available, all the mappers use the same seed, thus
   * all the mappers should take the same time to build their trees.
   *
   * @param conf job configuration
   * @return the value of {@code debug.mahout.rf.single.seed}, {@code false} when it is not set
   */
  private static boolean isSingleSeed(Configuration conf) {
    return conf.getBoolean("debug.mahout.rf.single.seed", false);
  }

  /**
   * Creates the reader that enumerates the tree ids of the given split.
   *
   * @param split must be an {@link InMemInputSplit}
   * @param context not used
   * @throws IllegalArgumentException if {@code split} is not an {@link InMemInputSplit}
   */
  @Override
  public RecordReader<IntWritable,NullWritable> createRecordReader(InputSplit split,
                                                                   TaskAttemptContext context)
    throws IOException, InterruptedException {
    Preconditions.checkArgument(split instanceof InMemInputSplit);
    return new InMemRecordReader((InMemInputSplit) split);
  }

  /**
   * Generates the splits using the number of map tasks found in the job configuration
   * ({@code mapred.map.tasks}). When that property is missing, -1 is passed on and the call is rejected
   * by {@link #getSplits(Configuration, int)}.
   */
  @Override
  public List<InputSplit> getSplits(JobContext context) throws IOException, InterruptedException {
    Configuration conf = context.getConfiguration();
    int numSplits = conf.getInt("mapred.map.tasks", -1);

    return getSplits(conf, numSplits);
  }

  /**
   * Distributes the trees over {@code numSplits} splits. Every split but the last receives
   * {@code nbTrees / numSplits} trees (integer division); the last split receives all the remaining
   * trees.
   *
   * @param conf job configuration, read for the number of trees, the random seed and the single-seed flag
   * @param numSplits number of splits to generate, must be positive
   * @return exactly {@code numSplits} instances of {@link InMemInputSplit}
   * @throws IllegalArgumentException if {@code numSplits} is not positive
   */
  public List<InputSplit> getSplits(Configuration conf, int numSplits) {
    // fail early with an explicit message instead of a division by zero or a negative list capacity
    Preconditions.checkArgument(numSplits > 0, "numSplits must be positive, got %s", numSplits);

    int nbTrees = Builder.getNbTrees(conf);
    int splitSize = nbTrees / numSplits;

    seed = Builder.getRandomSeed(conf);
    isSingleSeed = isSingleSeed(conf);

    // rng still holds the generator created by a previous call, if any
    if (rng != null && seed != null) {
      log.warn("getSplits() was called more than once and the 'seed' is set, "
          + "this can lead to no-repeatable behavior");
    }

    // a generator is only needed when every split must get its own seed
    rng = seed == null || isSingleSeed ? null : RandomUtils.getRandom(seed);

    int firstId = 0;
    List<InputSplit> splits = Lists.newArrayListWithCapacity(numSplits);

    for (int index = 0; index < numSplits - 1; index++) {
      splits.add(new InMemInputSplit(firstId, splitSize, nextSeed()));
      firstId += splitSize;
    }

    // the last split also takes care of the remainder of the integer division
    splits.add(new InMemInputSplit(firstId, nbTrees - firstId, nextSeed()));

    return splits;
  }

  /**
   * @return the seed for the next InputSplit: {@code null} when no seed is configured, the configured
   *         seed itself in single-seed mode, otherwise the next value drawn from the generator
   */
  private Long nextSeed() {
    if (seed == null) {
      return null;
    } else if (isSingleSeed) {
      return seed;
    } else {
      return rng.nextLong();
    }
  }

  /**
   * RecordReader that emits one key per tree of its split: the tree id, going from
   * {@code firstId} to {@code firstId + nbTrees - 1}. The value is always the NullWritable singleton.
   */
  public static class InMemRecordReader extends RecordReader<IntWritable,NullWritable> {

    private final InMemInputSplit split;

    /** Number of keys handed out so far. */
    private int pos;

    /** Reused for every record; created by {@link #initialize(InputSplit, TaskAttemptContext)}. */
    private IntWritable key;

    private NullWritable value;

    public InMemRecordReader(InMemInputSplit split) {
      this.split = split;
    }

    /**
     * @return 0 before the first key is handed out, then the fraction of the split's trees that
     *         precede the current one ({@code (pos - 1) / nbTrees}); the current tree is not counted
     */
    @Override
    public float getProgress() throws IOException {
      return pos == 0 ? 0.0f : (float) (pos - 1) / split.nbTrees;
    }

    @Override
    public IntWritable getCurrentKey() throws IOException, InterruptedException {
      return key;
    }

    @Override
    public NullWritable getCurrentValue() throws IOException, InterruptedException {
      return value;
    }

    /**
     * Allocates the key and value holders. Both arguments are ignored, the split passed to the
     * constructor is the one being read.
     */
    @Override
    public void initialize(InputSplit arg0, TaskAttemptContext arg1) throws IOException, InterruptedException {
      key = new IntWritable();
      value = NullWritable.get();
    }

    /**
     * Moves to the next tree id of the split.
     *
     * @return {@code true} if the current key was set to a new tree id, {@code false} once all the
     *         trees of the split have been handed out
     */
    @Override
    public boolean nextKeyValue() throws IOException, InterruptedException {
      if (pos < split.nbTrees) {
        key.set(split.firstId + pos);
        pos++;
        return true;
      } else {
        return false;
      }
    }

    @Override
    public void close() throws IOException {
      // nothing to release
    }

  }

  /**
   * Custom InputSplit that indicates how many trees are built by each mapper
   */
  public static class InMemInputSplit extends InputSplit implements Writable {

    private static final String[] NO_LOCATIONS = new String[0];

    /** Id of the first tree of this split */
    private int firstId;

    /** Number of trees of this split */
    private int nbTrees;

    /** Random seed of this split, {@code null} if no seed is available */
    private Long seed;

    /** No-arg constructor, the fields are expected to be filled by {@link #readFields(DataInput)}. */
    public InMemInputSplit() { }

    public InMemInputSplit(int firstId, int nbTrees, Long seed) {
      this.firstId = firstId;
      this.nbTrees = nbTrees;
      this.seed = seed;
    }

    /**
     * @return the Id of the first tree of this split
     */
    public int getFirstId() {
      return firstId;
    }

    /**
     * @return the number of trees
     */
    public int getNbTrees() {
      return nbTrees;
    }

    /**
     * @return the random seed or null if no seed is available
     */
    public Long getSeed() {
      return seed;
    }

    /**
     * @return the number of trees of this split
     */
    @Override
    public long getLength() throws IOException {
      return nbTrees;
    }

    /**
     * @return always an empty array
     */
    @Override
    public String[] getLocations() throws IOException {
      return NO_LOCATIONS;
    }

    /**
     * Two splits are equal when they have the same first id, the same number of trees and the same
     * seed (both absent, or both present with the same value).
     */
    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }
      if (!(obj instanceof InMemInputSplit)) {
        return false;
      }

      InMemInputSplit split = (InMemInputSplit) obj;

      if (firstId != split.firstId || nbTrees != split.nbTrees) {
        return false;
      }
      if (seed == null) {
        return split.seed == null;
      } else {
        return seed.equals(split.seed);
      }
    }

    @Override
    public int hashCode() {
      return firstId + nbTrees + (seed == null ? 0 : seed.intValue());
    }

    @Override
    public String toString() {
      return String.format(Locale.ENGLISH, "[firstId:%d, nbTrees:%d, seed:%d]", firstId, nbTrees, seed);
    }

    /**
     * Reads the fields in the order used by {@link #write(DataOutput)}: first id, number of trees,
     * a flag telling if a seed follows, then the seed itself when the flag is set.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
      firstId = in.readInt();
      nbTrees = in.readInt();
      boolean isSeed = in.readBoolean();
      seed = isSeed ? in.readLong() : null;
    }

    /**
     * Writes the first id, the number of trees, a flag telling if a seed is available, and the seed
     * itself only when it is available.
     */
    @Override
    public void write(DataOutput out) throws IOException {
      out.writeInt(firstId);
      out.writeInt(nbTrees);
      out.writeBoolean(seed != null);
      if (seed != null) {
        out.writeLong(seed);
      }
    }

    /**
     * Convenience factory that creates an empty split and fills it from the given input.
     *
     * @param in input positioned on data produced by {@link #write(DataOutput)}
     * @return the deserialized split
     */
    public static InMemInputSplit read(DataInput in) throws IOException {
      InMemInputSplit split = new InMemInputSplit();
      split.readFields(in);
      return split;
    }
  }

}