package quickml.supervised.classifier.twoStageModel;
import quickml.data.AttributesMap;
import quickml.data.PredictionMap;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.classifier.Classifier;
import java.io.Serializable;
import java.util.Set;
/**
* Created by alexanderhawk on 10/7/14.
*/
public class TwoStageClassifier extends AbstractClassifier {
    /*
     * This class wraps two binary classifiers in a composite model (which is also a binary classifier)
     * that predicts the probability of an instance having a positive label in both of the situations
     * the wrapped classifiers respectively make predictions for.
     *
     * The composite positive probability is the product of the positive probabilities reported by the
     * two wrapped classifiers.
     * NOTE(review): multiplying the two probabilities presumably assumes the second classifier models
     * its outcome conditional on the first (or that the two are independent) -- confirm with the
     * code that trains the wrapped classifiers.
     */

    /** Label under which the composite positive probability is stored in returned prediction maps. */
    private static final double POSITIVE_LABEL = 1.0;

    /** Label under which the complement of the positive probability is stored. */
    private static final double NEGATIVE_LABEL = 0.0;

    Classifier wrappedOne;
    Classifier wrappedTwo;

    /**
     * Creates a composite of two classifiers.
     *
     * @param wrappedOne the first wrapped classifier
     * @param wrappedTwo the second wrapped classifier
     */
    public TwoStageClassifier(Classifier wrappedOne, Classifier wrappedTwo) {
        this.wrappedOne = wrappedOne;
        this.wrappedTwo = wrappedTwo;
    }

    /**
     * Predicts the composite distribution over the labels {@code 1.0} and {@code 0.0}.
     *
     * @param attributes the instance to score
     * @return a map holding, under {@code 1.0}, the product of both wrapped classifiers' probabilities
     *         for label {@code 1.0}, and under {@code 0.0} the complement of that product
     */
    @Override
    public PredictionMap predict(AttributesMap attributes) {
        // Both stages must contribute: the composite is positive only when both stages are positive.
        double positiveProbability = wrappedOne.getProbability(attributes, POSITIVE_LABEL)
                * wrappedTwo.getProbability(attributes, POSITIVE_LABEL);
        return binaryPrediction(positiveProbability);
    }

    /**
     * Same as {@link #predict(AttributesMap)}, but both wrapped classifiers are queried through
     * {@code getProbabilityWithoutAttributes} with the given set of attributes to ignore.
     *
     * @param attributes the instance to score
     * @param attributesToIgnore attribute names passed through to both wrapped classifiers
     * @return a map holding the composite positive probability under {@code 1.0} and its complement
     *         under {@code 0.0}
     */
    @Override
    public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) {
        double positiveProbability =
                wrappedOne.getProbabilityWithoutAttributes(attributes, POSITIVE_LABEL, attributesToIgnore)
                        * wrappedTwo.getProbabilityWithoutAttributes(attributes, POSITIVE_LABEL, attributesToIgnore);
        return binaryPrediction(positiveProbability);
    }

    /**
     * Returns the product of the probabilities the two wrapped classifiers assign to {@code label}.
     *
     * NOTE(review): for a label other than the positive one this product is not necessarily equal to
     * the complement stored by {@link #predict(AttributesMap)}; behavior is kept as-is because callers
     * may depend on it -- confirm the intended semantics for negative labels.
     *
     * @param attributesMap the instance to score
     * @param label the label whose probability is requested
     * @return the product of both wrapped classifiers' probabilities for {@code label}
     */
    @Override
    public double getProbability(AttributesMap attributesMap, Serializable label) {
        double firstStage = wrappedOne.getProbability(attributesMap, label);
        double secondStage = wrappedTwo.getProbability(attributesMap, label);
        return firstStage * secondStage;
    }

    /**
     * Returns the product of the probabilities the two wrapped classifiers assign to {@code label}
     * when queried through {@code getProbabilityWithoutAttributes}.
     *
     * @param attributesMap the instance to score
     * @param label the label whose probability is requested
     * @param attributesToIgnore attribute names passed through to both wrapped classifiers
     * @return the product of both wrapped classifiers' probabilities for {@code label}
     */
    @Override
    public double getProbabilityWithoutAttributes(AttributesMap attributesMap, Serializable label, Set<String> attributesToIgnore) {
        double firstStage = wrappedOne.getProbabilityWithoutAttributes(attributesMap, label, attributesToIgnore);
        double secondStage = wrappedTwo.getProbabilityWithoutAttributes(attributesMap, label, attributesToIgnore);
        return firstStage * secondStage;
    }

    /** Builds a two-entry prediction map from the probability of the positive label. */
    private static PredictionMap binaryPrediction(double positiveProbability) {
        PredictionMap predictionMap = PredictionMap.newMap();
        predictionMap.put(POSITIVE_LABEL, positiveProbability);
        predictionMap.put(NEGATIVE_LABEL, 1.0 - positiveProbability);
        return predictionMap;
    }
}