package quickml.supervised.tree.regressionTree.branchFinders;
import com.google.common.base.Optional;
import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy;
import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy;
import quickml.supervised.tree.branchFinders.SortableLabelsCategoricalBranchFinder;
import quickml.supervised.tree.branchFinders.SplittingUtils;
import quickml.supervised.tree.constants.BranchType;
import quickml.supervised.tree.nodes.*;
import quickml.supervised.tree.branchingConditions.BranchingConditions;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.regressionTree.nodes.RTCatBranch;
import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter;
import quickml.supervised.tree.scorers.ScorerFactory;
import java.util.Set;
/**
 * Branch finder that reports {@link BranchType#RT_CATEGORICAL} and produces
 * {@link RTCatBranch} instances from {@link MeanValueCounter} statistics.
 *
 * <p>This class only supplies the branch type and the branch construction step;
 * all other behavior is inherited from {@link SortableLabelsCategoricalBranchFinder}.
 *
 * <p>Created by alexanderhawk on 6/11/15.
 */
public class RTCatBranchFinder extends SortableLabelsCategoricalBranchFinder<MeanValueCounter> {

    /**
     * Creates a finder; every argument is handed to the superclass constructor unchanged.
     *
     * @param candidateAttributes            forwarded to the superclass
     * @param branchingConditions            forwarded to the superclass
     * @param scorerFactory                  forwarded to the superclass
     * @param attributeValueIgnoringStrategy forwarded to the superclass
     * @param attributeIgnoringStrategy      forwarded to the superclass
     */
    public RTCatBranchFinder(
            Set<String> candidateAttributes,
            BranchingConditions<MeanValueCounter> branchingConditions,
            ScorerFactory<MeanValueCounter> scorerFactory,
            AttributeValueIgnoringStrategy<MeanValueCounter> attributeValueIgnoringStrategy,
            AttributeIgnoringStrategy attributeIgnoringStrategy) {
        super(
                candidateAttributes,
                branchingConditions,
                scorerFactory,
                attributeValueIgnoringStrategy,
                attributeIgnoringStrategy);
    }

    /**
     * @return always {@link BranchType#RT_CATEGORICAL}
     */
    @Override
    public BranchType getBranchType() {
        return BranchType.RT_CATEGORICAL;
    }

    /**
     * Builds an {@link RTCatBranch} from the given attribute statistics and split score.
     *
     * @param parent         passed to the new branch as its parent
     * @param attributeStats supplies the attribute and the aggregate statistics for the new branch
     * @param splitScore     supplies the true set, the probability of being in the true set, and the score
     * @return a present {@code Optional} holding the newly constructed branch; this
     *         implementation never returns an absent value
     */
    @Override
    protected Optional<? extends Branch<MeanValueCounter>> createBranch(
            Branch<MeanValueCounter> parent,
            AttributeStats<MeanValueCounter> attributeStats,
            SplittingUtils.SplitScore splitScore) {
        final RTCatBranch branch = new RTCatBranch(
                parent,
                attributeStats.getAttribute(),
                splitScore.trueSet,
                splitScore.probabilityOfBeingInTrueSet,
                splitScore.score,
                attributeStats.getAggregateStats());
        return Optional.of(branch);
    }
}