/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.eval; import org.junit.Assert; import org.junit.Test; import rapaio.core.CoreTools; import rapaio.data.Frame; import rapaio.data.Nominal; import rapaio.data.Var; import rapaio.datasets.Datasets; import java.io.IOException; import java.net.URISyntaxException; /** * Test for roc utility. * <p> * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 1/26/16. 
*/
public class ROCTest {

    /**
     * Absolute tolerance used for every floating-point comparison in this test.
     * <p>
     * The previous delta (1e-20) was far below the spacing between adjacent doubles near the
     * compared values (roughly 1e-16), so it effectively demanded bit-exact equality and would
     * fail on harmless round-off. 1e-12 keeps the check tight while tolerating that noise.
     */
    private static final double TOL = 1e-12;

    /**
     * Builds ROC curves over the iris data set and checks their printed summaries, the AUC of
     * the first curve, and the (fpr, tpr) pair at the row selected for the mean score.
     * <p>
     * The score is the first column of the iris frame and the labels come from its
     * {@code "class"} column. The first curve is built with {@code ROC.from(score, clazz, 3)}.
     * NOTE(review): {@code 3} presumably selects the positive label by its dictionary index;
     * confirm against {@code ROC.from}. The second curve additionally receives a synthetic
     * nominal variable that is "virginica" on even rows and "setosa" on odd rows.
     *
     * @throws IOException        propagated from loading the iris data set
     * @throws URISyntaxException propagated from loading the iris data set
     */
    @Test
    public void testRoc() throws IOException, URISyntaxException {
        Frame df = Datasets.loadIrisDataset();
        Var score = df.var(0);
        Var clazz = df.var("class");

        // ROC curve over the class labels, selected by index 3.
        ROC roc = ROC.from(score, clazz, 3);
        Assert.assertEquals("> ROC printSummary\n" +
                "\n" +
                "threshold , fpr , tpr , acc \n" +
                "Infinity , 0 , 0 , 0.6666667 \n" +
                "7.6 , 0 , 0.12 , 0.7066667 \n" +
                "7.1 , 0 , 0.24 , 0.7466667 \n" +
                "6.8 , 0.03 , 0.34 , 0.76 \n" +
                "6.4 , 0.11 , 0.62 , 0.8 \n" +
                "6 , 0.24 , 0.86 , 0.7933333 \n" +
                "5.7 , 0.37 , 0.96 , 0.74 \n" +
                "5.3 , 0.56 , 0.98 , 0.62 \n" +
                "5 , 0.79 , 0.98 , 0.4666667 \n" +
                "4.6 , 0.95 , 1 , 0.3666667 \n" +
                "4.3 , 1 , 1 , 0.3333333 \n" +
                "\n" +
                "AUC: 0.8871\n", roc.summary());
        Assert.assertEquals(0.8871, roc.auc(), TOL);

        // Rates at the row the ROC reports for the mean of the score.
        double midValue = CoreTools.mean(score).value();
        int midRow = roc.findRowForThreshold(midValue);
        // Guard before indexing so a bad lookup fails with the offending index in the message.
        Assert.assertTrue("findRowForThreshold returned an invalid row index: " + midRow, midRow >= 0);
        Assert.assertEquals(0.3, roc.data().value(midRow, ROC.fpr), TOL);
        Assert.assertEquals(0.94, roc.data().value(midRow, ROC.tpr), TOL);

        // Second curve: same score and labels plus a synthetic variable alternating by row parity.
        Nominal pred = Nominal.from(df.rowCount(), row -> row % 2 == 0 ? "virginica" : "setosa");
        Assert.assertEquals("> ROC printSummary\n" +
                "\n" +
                "threshold , fpr , tpr , acc \n" +
                "Infinity , 0 , 0 , 0.6666667 \n" +
                "7.6 , 0.04 , 0.04 , 0.6533333 \n" +
                "7.1 , 0.08 , 0.08 , 0.64 \n" +
                "6.8 , 0.14 , 0.12 , 0.6133333 \n" +
                "6.4 , 0.27 , 0.3 , 0.5866667 \n" +
                "6 , 0.45 , 0.44 , 0.5133333 \n" +
                "5.7 , 0.6 , 0.5 , 0.4333333 \n" +
                "5.3 , 0.77 , 0.56 , 0.34 \n" +
                "5 , 0.89 , 0.78 , 0.3333333 \n" +
                "4.6 , 0.97 , 0.96 , 0.34 \n" +
                "4.3 , 1 , 1 , 0.3333333 \n" +
                "\n" +
                "AUC: 0.4445\n", ROC.from(score, clazz, pred).summary());
    }
}