package quickml.supervised.tree.decisionTree.reducers;
import com.google.common.base.Optional;
import com.google.common.collect.Ordering;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* Created by alexanderhawk on 4/22/15.
*/
public class DTBinaryCatBranchReducer<I extends ClassifierInstance> extends DTCatBranchReducer<I> {
final Serializable minorityClassification;
public DTBinaryCatBranchReducer(List<I> trainingData, Serializable minorityClassification) {
super(trainingData);
this.minorityClassification = minorityClassification;
}
@Override
public Optional<AttributeStats<ClassificationCounter>> getAttributeStats(String attribute) {
Optional<AttributeStats<ClassificationCounter>> attributeStatsOptional = super.getAttributeStats(attribute);
if (!attributeStatsOptional.isPresent()) {
return Optional.absent();
}
AttributeStats<ClassificationCounter> attributeStats = attributeStatsOptional.get();
List<ClassificationCounter> attributesWithClassificationCounters = attributeStats.getStatsOnEachValue();
Collections.sort(attributesWithClassificationCounters, new Comparator<ClassificationCounter>() {
@Override
public int compare(ClassificationCounter cc1, ClassificationCounter cc2) {
double probOfMinority1 = cc1.getCount(minorityClassification) / cc1.getTotal();
double probOfMinority2 = cc2.getCount(minorityClassification) / cc2.getTotal();
return Ordering.natural().reverse().compare(probOfMinority1, probOfMinority2);
}
});
return Optional.of(attributeStats);
}
}