package quickml.supervised.crossValidation.lossfunctions;
import org.junit.Assert;
import org.junit.Test;
import quickml.data.PredictionMap;
import quickml.supervised.crossValidation.PredictionMapResult;
import quickml.supervised.crossValidation.PredictionMapResults;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierMSELossFunction;
import static com.google.common.collect.Lists.newArrayList;
/**
 * Unit tests for {@link ClassifierMSELossFunction}.
 */
public class ClassifierMSELossFunctionTest {

    /** Absolute tolerance used when comparing the computed loss to its expected value. */
    private static final double TOLERANCE = 0.0001;

    /**
     * Checks the loss over two weighted results that both predict their own true label.
     *
     * <p>NOTE(review): the expected 0.125 matches a weight-normalized mean of squared errors,
     * i.e. (2.0 * (1 - 0.75)^2 + 1.0 * (1 - 0.5)^2) / (2.0 + 1.0) = 0.375 / 3.0 = 0.125.
     * That formula is inferred from the numbers, not from the loss function's source — confirm
     * against {@link ClassifierMSELossFunction}.
     */
    @Test
    public void testGetTotalLoss() {
        final ClassifierMSELossFunction lossFunction = new ClassifierMSELossFunction();
        final PredictionMapResults results = new PredictionMapResults(newArrayList(
                createPredictionMapResult("test1", 0.75, 2.0),
                createPredictionMapResult("test1", 0.5, 1.0)));

        final double loss = lossFunction.getLoss(results);

        Assert.assertEquals(0.125, loss, TOLERANCE);
    }

    /**
     * Builds a result whose actual label is {@code label}. Its prediction map holds a single
     * entry that assigns {@code prediction} to that same label.
     *
     * @param label      the label used both as the prediction-map key and as the actual label
     * @param prediction the value stored for {@code label} in the prediction map
     * @param weight     the weight attached to the result
     * @return a new {@link PredictionMapResult} built from the arguments
     */
    private PredictionMapResult createPredictionMapResult(final String label, final double prediction, final double weight) {
        final PredictionMap predictions = PredictionMap.newMap();
        predictions.put(label, prediction);
        return new PredictionMapResult(predictions, label, weight);
    }
}