package quickml.supervised.ensembles.randomForest.randomDecisionForest;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.ClassifierInstance;
import quickml.data.PredictionMap;
import quickml.supervised.ensembles.randomForest.RandomForestBuilder;
import quickml.supervised.tree.decisionTree.DecisionTree;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.tree.decisionTree.DecisionTreeBuilder;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import static quickml.supervised.tree.constants.ForestOptions.*;
/**
 * Builds a {@link RandomDecisionForest} by training {@code numTrees} decision trees in
 * parallel on a fixed-size thread pool and combining them into one forest.
 *
 * <p>Each tree is built from a copy of the configured {@link DecisionTreeBuilder}
 * (see {@link #buildModel}). The thread pool is created per call to
 * {@link #buildPredictiveModel} and stored in an instance field, so a single builder
 * instance should not run overlapping builds.
 *
 * @param <I> the type of training instance consumed by the tree builder
 */
public class RandomDecisionForestBuilder<I extends ClassifierInstance> extends RandomForestBuilder<PredictionMap, RandomDecisionForest, I> {
    // TODO: copy treeBuilder before submitting (the copy currently happens inside the worker task, see buildModel)
    private static final Logger logger = LoggerFactory.getLogger(RandomDecisionForestBuilder.class);

    private final DecisionTreeBuilder<I> treeBuilder;
    private int executorThreadCount = Runtime.getRuntime().availableProcessors();
    private ExecutorService executorService;

    /**
     * Creates a forest builder around a default tree builder configured with
     * {@code IgnoreAttributesWithConstantProbability(0.7)} and {@code maxDepth(5)}.
     */
    public RandomDecisionForestBuilder() {
        this(new DecisionTreeBuilder<I>().attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.7)).maxDepth(5));
    }

    /**
     * Creates a forest builder that trains every tree from a copy of the given tree builder.
     *
     * @param treeBuilder the template tree builder
     */
    public RandomDecisionForestBuilder(DecisionTreeBuilder<I> treeBuilder) {
        this.treeBuilder = treeBuilder;
    }

    /**
     * Forwards the configuration to the underlying tree builder and, when the
     * {@code NUM_TREES} key is present, applies it as the number of trees.
     *
     * @param config configuration values keyed by option name
     */
    @Override
    public void updateBuilderConfig(Map<String, Serializable> config) {
        treeBuilder.updateBuilderConfig(config);
        if (config.containsKey(NUM_TREES.name())) {
            this.numTrees((Integer) config.get(NUM_TREES.name()));
        }
    }

    /**
     * Sets the number of trees to build.
     *
     * @param numTrees number of trees in the forest
     * @return this builder, for chaining
     */
    public RandomDecisionForestBuilder<I> numTrees(int numTrees) {
        super.numTrees = numTrees;
        return this;
    }

    /**
     * Sets the size of the thread pool used to build trees. Defaults to the number of
     * available processors.
     *
     * @param threadCount number of worker threads
     * @return this builder, for chaining
     */
    public RandomDecisionForestBuilder<I> executorThreadCount(int threadCount) {
        this.executorThreadCount = threadCount;
        return this;
    }

    /**
     * Builds {@code numTrees} decision trees on a fresh thread pool and returns the forest
     * made of them, together with the union of the classifications reported by each tree.
     *
     * @param trainingData the instances every tree is trained on
     * @return the assembled forest
     * @throws RuntimeException if any tree build fails or the calling thread is interrupted
     *                          while waiting; the original exception is the cause
     */
    @Override
    public RandomDecisionForest buildPredictiveModel(Iterable<I> trainingData) {
        final ExecutorService pool = Executors.newFixedThreadPool(executorThreadCount);
        executorService = pool;
        try {
            logger.info("Building random forest with {} trees", numTrees);

            List<Future<DecisionTree>> treeFutures = Lists.newArrayListWithCapacity(numTrees);
            List<DecisionTree> decisionTrees = Lists.newArrayListWithCapacity(numTrees);

            // Submit all tree building jobs to the executor
            for (int treeIndex = 0; treeIndex < numTrees; treeIndex++) {
                treeFutures.add(submitTreeBuild(trainingData, treeIndex));
            }

            // Collect all completed trees. Will block until complete
            collectTreeFutures(decisionTrees, treeFutures);

            Set<Serializable> classifications = new HashSet<>();
            for (DecisionTree decisionTree : decisionTrees) {
                classifications.addAll(decisionTree.getClassifications());
            }
            return new RandomDecisionForest(decisionTrees, classifications);
        } finally {
            // On the normal path collectTreeFutures has already shut the pool down. If a
            // tree build failed (or an override skipped the shutdown) the pool would
            // otherwise leak its worker threads, so stop it and cancel outstanding builds.
            if (!pool.isShutdown()) {
                pool.shutdownNow();
            }
        }
    }

    /** Submits the build of a single tree to the executor and returns its future. */
    private Future<DecisionTree> submitTreeBuild(final Iterable<I> trainingData, final int treeIndex) {
        return executorService.submit(new Callable<DecisionTree>() {
            @Override
            public DecisionTree call() throws Exception {
                return buildModel(trainingData, treeIndex);
            }
        });
    }

    /** Builds one tree from a copy of the template tree builder. */
    private DecisionTree buildModel(Iterable<I> trainingData, int treeIndex) {
        logger.debug("Building oldTree {} of {}", treeIndex, numTrees);
        return treeBuilder.copy().buildPredictiveModel(trainingData);
    }

    /**
     * Waits for every future in order, appends each resulting tree to {@code decisionTrees},
     * and then shuts the executor down.
     *
     * @param decisionTrees output list receiving the built trees
     * @param treeFutures   futures of the submitted tree builds
     * @throws RuntimeException if waiting on any future fails
     */
    protected void collectTreeFutures(List<DecisionTree> decisionTrees, List<Future<DecisionTree>> treeFutures) {
        for (Future<DecisionTree> treeFuture : treeFutures) {
            collectTreeFuture(decisionTrees, treeFuture);
        }
        executorService.shutdown();
    }

    /**
     * Blocks on a single future and appends its tree to {@code decisionTrees}. Any failure
     * is logged and rethrown wrapped in a {@link RuntimeException}.
     */
    private void collectTreeFuture(List<DecisionTree> decisionTrees, Future<DecisionTree> treeFuture) {
        try {
            decisionTrees.add(treeFuture.get());
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            logger.error("Interrupted while waiting for a decision tree to be built", e);
            throw new RuntimeException(e);
        } catch (Exception e) {
            logger.error("Failed to build a decision tree", e);
            throw new RuntimeException(e);
        }
    }
}