package quickml.supervised.crossValidation.lossfunctions;

import org.junit.Test;
import quickml.data.PredictionMap;
import quickml.supervised.crossValidation.PredictionMapResult;
import quickml.supervised.crossValidation.PredictionMapResults;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import static com.google.common.collect.Lists.newArrayList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Unit tests for {@link WeightedAUCCrossValLossFunction}.
 *
 * Created by Chris on 5/5/2014.
 */
public class WeightedAUCCrossValLossFunctionTest {

    /**
     * Feeding results that carry three distinct labels into {@code getLoss}
     * is expected to raise a {@link RuntimeException}.
     */
    @Test(expected = RuntimeException.class)
    public void testOnlySupportBinaryClassifications() {
        WeightedAUCCrossValLossFunction lossFunction = new WeightedAUCCrossValLossFunction("test1");
        PredictionMap emptyPredictions = PredictionMap.newMap();

        // All three results share the same (empty) prediction map and a weight of 1.0.
        List<PredictionMapResult> threeLabelResults = newArrayList();
        for (String label : new String[]{"label1", "label2", "label3"}) {
            threeLabelResults.add(new PredictionMapResult(emptyPredictions, label, 1.0));
        }

        lossFunction.getLoss(new PredictionMapResults(threeLabelResults));
    }

    /**
     * Checks the loss computed from four unit-weight predictions against a
     * hand-computed value.
     */
    @Test
    public void testGetTotalLoss() {
        WeightedAUCCrossValLossFunction lossFunction = new WeightedAUCCrossValLossFunction("test1");

        List<PredictionMapResult> predictions = new LinkedList<>();
        predictions.add(createPredictionMapResult("test1", 0.5, "test1"));
        predictions.add(createPredictionMapResult("test1", 0.3, "test1"));
        predictions.add(createPredictionMapResult("test1", 0.4, "test2"));
        predictions.add(createPredictionMapResult("test1", 0.2, "test2"));

        double loss = lossFunction.getLoss(new PredictionMapResults(predictions));

        // AUC Points at 0:0 0:.5 .5:.5 1:.5 1:1 - expected area should be .25
        assertEquals(.25, loss, 0.00001);
    }

    /**
     * Builds a unit-weight result whose prediction map holds a single entry.
     *
     * @param label      key stored in the prediction map
     * @param prediction value stored under {@code label}
     * @param actual     label passed to the result as the actual outcome
     * @return a {@link PredictionMapResult} with weight 1.0
     */
    private PredictionMapResult createPredictionMapResult(final String label, final double prediction, final String actual) {
        PredictionMap singleEntry = PredictionMap.newMap();
        singleEntry.put(label, prediction);
        return new PredictionMapResult(singleEntry, actual, 1.0);
    }

    /**
     * After {@link Collections#sort}, the AUC data must be in non-decreasing
     * order of positive-classification probability.
     */
    @Test
    public void testSortDataByProbability() {
        List<WeightedAUCCrossValLossFunction.AUCData> sorted = getAucDataList();
        // order by probability ascending
        Collections.sort(sorted);

        double previous = 0;
        for (WeightedAUCCrossValLossFunction.AUCData entry : sorted) {
            double current = entry.getProbabilityOfPositiveClassification();
            assertTrue(current >= previous);
            previous = current;
        }
    }

    /**
     * Verifies the false/true positive rates of the points returned by
     * {@code getAUCPoint}.
     */
    @Test
    public void testGetAUCPoint() {
        // FPR = FP / (FP + TN)
        // TPR = TP / (TP + FN)
        // The expected rates below are consistent with an argument order of
        // (TP, FP, TN, FN) - confirm against getAUCPoint's signature.
        WeightedAUCCrossValLossFunction lossFunction = new WeightedAUCCrossValLossFunction("test1");

        assertRates(lossFunction.getAUCPoint(2, 2, 0, 1), 1.0, 2.0 / 3.0);
        assertRates(lossFunction.getAUCPoint(2, 1, 1, 1), 0.5, 2.0 / 3.0);
        assertRates(lossFunction.getAUCPoint(2, 0, 0, 1), 0.0, 2.0 / 3.0);
        assertRates(lossFunction.getAUCPoint(0, 1, 3, 0), 0.25, 0.0);
    }

    /**
     * Asserts both rates of an AUC point within a tolerance of 0.001,
     * false positive rate first.
     */
    private void assertRates(WeightedAUCCrossValLossFunction.AUCPoint point,
                             double expectedFalsePositiveRate,
                             double expectedTruePositiveRate) {
        assertEquals(expectedFalsePositiveRate, point.getFalsePositiveRate(), 0.001);
        assertEquals(expectedTruePositiveRate, point.getTruePositiveRate(), 0.001);
    }

    /**
     * Supplies six unit-weight AUC data entries whose probabilities are
     * deliberately not in sorted order.
     */
    private List<WeightedAUCCrossValLossFunction.AUCData> getAucDataList() {
        // labels[i] pairs with probabilities[i]
        String[] labels = {"test1", "test0", "test0", "test1", "test1", "test1"};
        double[] probabilities = {0.5, 0.3, 0.6, 0.2, 0.7, 0.8};

        List<WeightedAUCCrossValLossFunction.AUCData> unsorted = new ArrayList<>();
        for (int i = 0; i < labels.length; i++) {
            unsorted.add(new WeightedAUCCrossValLossFunction.AUCData(labels[i], 1.0, probabilities[i]));
        }
        return unsorted;
    }

    /**
     * When every entry carries probability 0.0 (one in five labelled
     * "test1", the rest "test0"), the AUC loss is expected to be 0.5.
     */
    @Test
    public void testAUCWhenAlwaysPredict0() {
        WeightedAUCCrossValLossFunction lossFunction = new WeightedAUCCrossValLossFunction("test1");

        // mahout only stores 10000 data points, test against less than what they consider
        int dataSize = 9000;
        List<WeightedAUCCrossValLossFunction.AUCData> constantPredictions = new ArrayList<>();
        for (int i = 0; i < dataSize; i++) {
            String classification = (i % 5 == 0) ? "test1" : "test0";
            constantPredictions.add(new WeightedAUCCrossValLossFunction.AUCData(classification, 1.0, 0.0));
        }

        // order by probability ascending
        Collections.sort(constantPredictions);

        ArrayList<WeightedAUCCrossValLossFunction.AUCPoint> rocPoints =
                lossFunction.getAUCPointsFromData(constantPredictions);
        double loss = lossFunction.getAUCLoss(rocPoints);

        assertEquals(0.5, loss, 1E-7);
    }
}