package quickml.supervised.tree;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.supervised.Utils;
import quickml.supervised.tree.branchFinders.BranchFinderAndReducerFactory;
import quickml.supervised.tree.reducers.Reducer;
import quickml.supervised.tree.reducers.ReducerFactory;
import quickml.supervised.tree.summaryStatistics.ValueCounter;
import quickml.supervised.tree.branchFinders.BranchFinder;
import quickml.supervised.tree.nodes.Branch;
import quickml.supervised.tree.nodes.Leaf;
import quickml.supervised.tree.nodes.Node;
import quickml.supervised.tree.branchingConditions.BranchingConditions;
import quickml.supervised.tree.treeBuildContexts.TreeContext;
import quickml.supervised.tree.treeBuildContexts.TreeContextBuilder;
import java.io.Serializable;
import java.util.*;
/**
 * Recursively builds a binary tree of {@link Node}s from a list of training instances.
 *
 * <p>All policy decisions (when to stop splitting, how candidate branches are found and
 * scored, how leaves are built, how instances are summarized) are delegated to the
 * {@link TreeContext} produced by the {@link TreeContextBuilder} held by this helper.
 *
 * @param <I>  the training instance type
 * @param <VC> the value-counter (summary statistics) type stored in the tree's nodes
 */
public class TreeBuilderHelper<I extends InstanceWithAttributesMap<?>, VC extends ValueCounter<VC>> {

    protected TreeContextBuilder<I, VC> treeContextBuilder;

    /**
     * Creates a helper that works on its own copy of the supplied builder
     * (obtained via {@code treeContextBuilder.copy()}).
     *
     * @param treeContextBuilder the builder to copy
     */
    public TreeBuilderHelper(TreeContextBuilder<I, VC> treeContextBuilder) {
        this.treeContextBuilder = treeContextBuilder.copy();
    }

    /**
     * Returns a new helper constructed from this helper's context builder; the
     * constructor copies the builder again, so the two helpers do not share it.
     *
     * @return a new {@code TreeBuilderHelper}
     */
    public TreeBuilderHelper<I, VC> copy() {
        // Explicit type arguments instead of the raw type avoid an unchecked conversion.
        return new TreeBuilderHelper<I, VC>(treeContextBuilder);
    }

    /**
     * Forwards the given configuration to the underlying context builder.
     *
     * @param cfg configuration entries passed to {@code TreeContextBuilder.setConfig}
     */
    public void updateBuilderConfig(Map<String, Serializable> cfg) {
        treeContextBuilder.setConfig(cfg);
    }

    /**
     * Builds a tree context for the training data and constructs the tree from it.
     *
     * @param trainingData the instances to train on; must be non-null and non-empty
     * @return the root node of the constructed tree
     * @throws IllegalArgumentException if {@code trainingData} is null or empty
     *                                  (raised by the check in {@link #createNode})
     */
    public Node<VC> computeNodes(List<I> trainingData) {
        TreeContext<I, VC> treeContext = treeContextBuilder.buildContext(trainingData);
        // The root has no parent branch.
        return createNode(null, trainingData, treeContext);
    }

    /**
     * Creates the node for the given training data, recursing into both children when a
     * usable split is found.
     *
     * <p>A leaf is returned when the branching conditions disallow adding children, when no
     * candidate branch is accepted, or when the best branch would send every instance to
     * the same side.
     *
     * @param parent       the branch this node hangs off, or {@code null} for the root
     * @param trainingData the instances reaching this node; must be non-null and non-empty
     * @param tc           the tree context supplying branching conditions, branch finders,
     *                     the leaf builder and the value-counter producer
     * @return the best branch with both children attached, or a leaf
     * @throws IllegalArgumentException if {@code trainingData} is null or empty
     */
    protected Node<VC> createNode(Branch<VC> parent, List<I> trainingData, TreeContext<I, VC> tc) {
        Preconditions.checkArgument(trainingData != null && !trainingData.isEmpty(), "Can't build a oldTree with no training data");

        BranchingConditions<VC> branchingConditions = tc.getBranchingConditions();
        VC aggregateStats = getAggregateStats(tc, trainingData);
        if (!branchingConditions.canTryAddingChildren(parent, aggregateStats)) {
            return getLeaf(parent, aggregateStats, tc);
        }

        Optional<? extends Branch<VC>> bestBranchOptional = findBestBranch(parent, trainingData, tc);
        if (!bestBranchOptional.isPresent()) {
            return getLeaf(parent, aggregateStats, tc);
        }

        Branch<VC> bestBranch = bestBranchOptional.get();
        Utils.TrueFalsePair<I> trueFalsePair = Utils.setTrueAndFalseTrainingSets(trainingData, bestBranch);
        // A split that leaves one side empty does not partition the data; stop here.
        if (trueFalsePair.trueTrainingSet.isEmpty() || trueFalsePair.falseTrainingSet.isEmpty()) {
            return getLeaf(parent, aggregateStats, tc);
        }

        bestBranch.setTrueChild(createNode(bestBranch, trueFalsePair.trueTrainingSet, tc));
        bestBranch.setFalseChild(createNode(bestBranch, trueFalsePair.falseTrainingSet, tc));
        return bestBranch;
    }

    /**
     * Asks every branch finder in the context for its best branch and returns the one with
     * the highest score.
     *
     * <p>A candidate is accepted only if its score is strictly greater than the best score
     * seen so far (starting from 0) and the branching conditions do not reject that score.
     *
     * @param parent    the parent branch, or {@code null} for the root
     * @param instances the instances to be reduced and split
     * @param tc        the tree context supplying the branch finders and reducer factories
     * @return the best accepted branch, or {@code Optional.absent()} if none was accepted
     */
    private Optional<? extends Branch<VC>> findBestBranch(Branch<VC> parent, List<I> instances, TreeContext<I, VC> tc) {
        double bestScore = 0;
        Optional<? extends Branch<VC>> bestBranchOptional = Optional.absent();
        List<? extends BranchFinderAndReducerFactory<I, VC>> branchFindersAndReducers = tc.getBranchFindersAndReducers();
        // TODO (carried over from the original source): check how RT cat branch getting added twice?
        for (BranchFinderAndReducerFactory<I, VC> branchFinderAndReducerFactory : branchFindersAndReducers) {
            // The reduction of instances to value counters is kept separate from the branch
            // finders, which do not need to know anything about the form of the instances.
            ReducerFactory<I, VC> reducerFactory = branchFinderAndReducerFactory.getReducerFactory();
            Reducer<I, VC> reducer = reducerFactory.getReducer(instances);
            BranchFinder<VC> branchFinder = branchFinderAndReducerFactory.getBranchFinder();
            Optional<? extends Branch<VC>> thisBranchOptional = branchFinder.findBestBranch(parent, reducer);
            if (!thisBranchOptional.isPresent()) {
                continue;
            }
            Branch<VC> thisBranch = thisBranchOptional.get();
            if (isBestSplitSoFar(tc, bestScore, thisBranch)) {
                bestBranchOptional = thisBranchOptional;
                // Use the same accessor that isBestSplitSoFar compared against.
                bestScore = thisBranch.getScore();
            }
        }
        return bestBranchOptional;
    }

    /**
     * Tells whether the branch beats the current best score and is not rejected by the
     * context's branching conditions.
     *
     * @param itbc       the tree context supplying the branching conditions
     * @param bestScore  the best score accepted so far
     * @param thisBranch the candidate branch
     * @return {@code true} if the candidate's score is strictly greater than
     *         {@code bestScore} and {@code isInvalidSplit} returns {@code false} for it
     */
    private boolean isBestSplitSoFar(TreeContext<I, VC> itbc, double bestScore, Branch<VC> thisBranch) {
        double candidateScore = thisBranch.getScore();
        return candidateScore > bestScore && !itbc.getBranchingConditions().isInvalidSplit(candidateScore);
    }

    /**
     * Builds a leaf for the given parent and statistics using the context's leaf builder.
     *
     * @param parent       the parent branch, or {@code null} for the root
     * @param valueCounter the aggregate statistics of the instances reaching the leaf
     * @param itbc         the tree context supplying the leaf builder
     * @return the newly built leaf
     */
    protected Leaf<VC> getLeaf(Branch<VC> parent, VC valueCounter, TreeContext<I, VC> itbc) {
        return itbc.getLeafBuilder().buildLeaf(parent, valueCounter);
    }

    /**
     * Summarizes the training data into a single value counter using the context's
     * value-counter producer.
     *
     * @param itbc         the tree context supplying the value-counter producer
     * @param trainingData the instances to summarize
     * @return the value counter for {@code trainingData}
     */
    private VC getAggregateStats(TreeContext<I, VC> itbc, List<I> trainingData) {
        return itbc.getValueCounterProducer().getValueCounter(trainingData);
    }
}