package quickml.supervised.classifier.downsampling;
import com.google.common.collect.Maps;
import quickml.data.AttributesMap;
import quickml.data.PredictionMap;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.classifier.Classifier;
import java.io.Serializable;
import java.util.Map;
import java.util.Set;
/**
* Created by ian on 4/22/14.
*/
public class DownsamplingClassifier extends AbstractClassifier {
private static final long serialVersionUID = -265699047882740160L;
public final Classifier wrappedClassifier;
private final Serializable minorityClassification;
private final Serializable majorityClassification;
private final double dropProbability;
public DownsamplingClassifier(final Classifier wrappedClassifier, final Serializable majorityClassification, final Serializable minorityClassification, final double dropProbability) {
this.wrappedClassifier = wrappedClassifier;
this.majorityClassification = majorityClassification;
this.minorityClassification = minorityClassification;
this.dropProbability = dropProbability;
}
public double getProbability(AttributesMap attributes, Serializable classification) {
double uncorrectedProbability = wrappedClassifier.getProbability(attributes, minorityClassification);
double probabilityOfMinorityInstance = DownsamplingUtils.correctProbability(dropProbability, uncorrectedProbability);
if (classification.equals(minorityClassification)) {
return probabilityOfMinorityInstance;
} else {
return 1 - probabilityOfMinorityInstance;
}
}
@Override
public double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set<String> attributesToIgnore) {
double uncorrectedProbability = wrappedClassifier.getProbabilityWithoutAttributes(attributes, minorityClassification, attributesToIgnore);
double probabilityOfMinorityInstance = DownsamplingUtils.correctProbability(dropProbability, uncorrectedProbability);
if (classification.equals(minorityClassification)) {
return probabilityOfMinorityInstance;
} else {
return 1 - probabilityOfMinorityInstance;
}
}
@Override
public PredictionMap predict(AttributesMap attributes) {
Map<Serializable, Double> probsByClassification = Maps.newHashMap();
probsByClassification.put(minorityClassification, getProbability(attributes, minorityClassification));
probsByClassification.put(majorityClassification, getProbability(attributes, majorityClassification));
return new PredictionMap(probsByClassification);
}
@Override
public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) {
Map<Serializable, Double> probsByClassification = Maps.newHashMap();
probsByClassification.put(minorityClassification, getProbabilityWithoutAttributes(attributes, minorityClassification, attributesToIgnore));
probsByClassification.put(majorityClassification, getProbabilityWithoutAttributes(attributes, majorityClassification, attributesToIgnore));
return new PredictionMap(probsByClassification);
}
@Override
public Serializable getClassificationByMaxProb(final AttributesMap attributes) {
return wrappedClassifier.getClassificationByMaxProb(attributes);
}
public double getDropProbability() {
return dropProbability;
}
public Serializable getMajorityClassification() {
return majorityClassification;
}
}