package quickml.supervised.crossValidation.lossfunctions;
import com.google.common.collect.Lists;
import quickml.data.PredictionMap;
import quickml.supervised.crossValidation.PredictionMapResult;
import quickml.supervised.crossValidation.PredictionMapResults;
import quickml.supervised.classifier.downsampling.DownsamplingUtils;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;
import java.io.Serializable;
import java.util.List;
import java.util.Objects;
/**
 * A {@link ClassifierLossFunction} decorator that corrects prediction results for downsampling
 * before delegating to a wrapped loss function.
 *
 * <p>Every {@link PredictionMapResult} is first passed through a {@link CorrectionFunction}. That
 * function may adjust both the predicted probabilities and the instance weight. The corrected
 * results are then scored by the wrapped loss function.
 *
 * Created by alexanderhawk on 10/23/14.
 */
public class LossFunctionCorrectedForDownsampling extends ClassifierLossFunction {

    /** The loss function evaluated on the corrected results. */
    ClassifierLossFunction wrappedLossFunction;

    /** Maps each uncorrected result to its corrected counterpart. */
    CorrectionFunction correctionFunction;

    /**
     * Creates a loss function that applies an arbitrary correction before delegating.
     *
     * @param wrappedLossFunction the loss function evaluated on the corrected results
     * @param correctionFunction the correction applied to every result
     * @throws NullPointerException if either argument is {@code null}
     */
    public LossFunctionCorrectedForDownsampling(ClassifierLossFunction wrappedLossFunction,
                                                CorrectionFunction correctionFunction) {
        this.correctionFunction = Objects.requireNonNull(correctionFunction, "correctionFunction");
        this.wrappedLossFunction = Objects.requireNonNull(wrappedLossFunction, "wrappedLossFunction");
    }

    /**
     * Creates a loss function that corrects for negative instances having been downsampled.
     *
     * @param wrappedLossFunction the loss function evaluated on the corrected results
     * @param dropProbability the probability with which negative instances were dropped; must be
     *     in {@code [0, 1)}
     * @param negativeLabel the label that identifies negative instances
     * @throws IllegalArgumentException if {@code dropProbability} is not in {@code [0, 1)}
     * @throws NullPointerException if {@code wrappedLossFunction} is {@code null}
     */
    public LossFunctionCorrectedForDownsampling(ClassifierLossFunction wrappedLossFunction,
                                                double dropProbability,
                                                Serializable negativeLabel) {
        this(wrappedLossFunction, new NegativeInstanceCorrectionFunction(negativeLabel, dropProbability));
    }

    /**
     * Corrects every result for downsampling and returns the wrapped loss function's loss on the
     * corrected results.
     */
    @Override
    public Double getLoss(PredictionMapResults results) {
        return wrappedLossFunction.getLoss(correctLabelPredictionWeights(results));
    }

    /** Returns the wrapped loss function's name prefixed with {@code "DOWNSAMPLED_"}. */
    @Override
    public String getName() {
        return "DOWNSAMPLED_" + wrappedLossFunction.getName();
    }

    /**
     * Applies {@link #correctionFunction} to every result, preserving iteration order.
     *
     * @param uncorrectedPredictionMapResults the results as produced by the downsampled model
     * @return a new {@link PredictionMapResults} holding the corrected results
     */
    public PredictionMapResults correctLabelPredictionWeights(PredictionMapResults uncorrectedPredictionMapResults) {
        List<PredictionMapResult> results = Lists.newArrayList();
        for (PredictionMapResult result : uncorrectedPredictionMapResults) {
            results.add(correctionFunction.getCorrectedLabelPredictionWeight(result));
        }
        return new PredictionMapResults(results);
    }

    // TODO[mk] - internal class, doesn't need to be an interface
    /** Maps a single uncorrected label/prediction/weight triple to a corrected result. */
    public interface CorrectionFunction {
        PredictionMapResult getCorrectedLabelPredictionWeight(LabelPredictionWeight<Serializable, PredictionMap> labelPredictionWeight);
    }

    /**
     * Corrects results of a model trained on data from which negative instances were randomly
     * dropped with probability {@code dropProbability}.
     *
     * <p>This class assumes every instance is either positive or negative. Every prediction key
     * other than {@link #negativeLabel} is treated as a positive class.
     *
     * <p>This class is {@code static}: it needs no enclosing loss-function instance. It can
     * therefore be built independently and passed to the
     * {@code (ClassifierLossFunction, CorrectionFunction)} constructor.
     */
    public static class NegativeInstanceCorrectionFunction implements CorrectionFunction {

        /** Label identifying negative instances; {@code 0.0} unless set by the constructor. */
        Serializable negativeLabel = Double.valueOf(0.0);

        /** Probability with which negative instances were dropped; always in {@code [0, 1)}. */
        double dropProbability;

        /**
         * @param negativeLabel the label that identifies negative instances
         * @param dropProbability the negative-instance drop probability, in {@code [0, 1)}
         * @throws IllegalArgumentException if {@code dropProbability} is not in {@code [0, 1)}
         */
        NegativeInstanceCorrectionFunction(Serializable negativeLabel, double dropProbability) {
            this.negativeLabel = negativeLabel;
            this.dropProbability = checkDropProbability(dropProbability);
        }

        /**
         * Uses the default negative label {@code 0.0}.
         *
         * @param dropProbability the negative-instance drop probability, in {@code [0, 1)}
         * @throws IllegalArgumentException if {@code dropProbability} is not in {@code [0, 1)}
         */
        NegativeInstanceCorrectionFunction(double dropProbability) {
            this.dropProbability = checkDropProbability(dropProbability);
        }

        /**
         * Rejects drop probabilities outside {@code [0, 1)}, including NaN.
         *
         * <p>A value of 1 would make the weight correction divide by zero. A value above 1 would
         * make the corrected weights negative.
         */
        private static double checkDropProbability(double dropProbability) {
            if (!(dropProbability >= 0.0 && dropProbability < 1.0)) {
                throw new IllegalArgumentException(
                        "dropProbability must be in [0, 1), got " + dropProbability);
            }
            return dropProbability;
        }

        /**
         * Corrects the predicted probabilities and the weight of one result.
         *
         * <p>Probabilities of positive keys are passed through
         * {@link DownsamplingUtils#correctProbability}. The negative key's probability is corrected
         * via its complement. As a result, a two-class prediction whose probabilities sum to 1
         * still sums to 1 after correction.
         *
         * <p>If the instance's label is the negative label, its weight is divided by
         * {@code 1 - dropProbability}.
         *
         * @param labelPredictionWeight the uncorrected result
         * @return a new result with the corrected prediction map, the original label, and the
         *     corrected weight
         */
        @Override
        public PredictionMapResult getCorrectedLabelPredictionWeight(LabelPredictionWeight<Serializable, PredictionMap> labelPredictionWeight) {
            PredictionMap uncorrectedPrediction = labelPredictionWeight.getPrediction();
            PredictionMap correctedPredictionMap = PredictionMap.newMap();
            for (Serializable key : uncorrectedPrediction.keySet()) {
                double uncorrectedProbability = uncorrectedPrediction.get(key);
                double correctedProbability;
                if (Objects.equals(key, negativeLabel)) {
                    // Correct the complementary (positive) probability, then convert it back.
                    correctedProbability = 1.0 - DownsamplingUtils.correctProbability(
                            dropProbability, 1.0 - uncorrectedProbability);
                } else {
                    correctedProbability = DownsamplingUtils.correctProbability(
                            dropProbability, uncorrectedProbability);
                }
                correctedPredictionMap.put(key, correctedProbability);
            }

            double correctedWeight = labelPredictionWeight.getWeight();
            if (Objects.equals(labelPredictionWeight.getLabel(), negativeLabel)) {
                // Each retained negative stands in for 1 / (1 - dropProbability) original negatives.
                correctedWeight /= (1.0 - dropProbability);
            }
            return new PredictionMapResult(correctedPredictionMap, labelPredictionWeight.getLabel(), correctedWeight);
        }
    }
}