package test.java.integration.tests.gal;

import com.datascience.core.nominal.decision.LabelProbabilityDistributionCostCalculators;
import com.datascience.core.nominal.decision.WorkerEstimator;
import com.datascience.mv.BatchMV;
import com.datascience.mv.IncrementalMV;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.HashMap;
import java.util.Map;

import static junitparams.JUnitParamsRunner.$;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;

/**
 * Parameterized scenario that compares values computed through the helpers
 * inherited from {@link BaseTestScenario} with the figures exposed by
 * {@code summaryResultsParser.getDataQuality()}.
 *
 * <p>{@link #initSetup(Setup)} must be called before the tests run: it selects
 * the algorithm and fills the static table of computed values that the
 * parameterized tests read.
 */
@RunWith(JUnitParamsRunner.class)
public class MVBaseTestScenario extends BaseTestScenario {

    /** Maximum absolute difference tolerated between a summary value and a computed value. */
    private static final double TOLERANCE = 0.05;

    /** Computed values keyed by short metric name; populated by {@link #initSetup(Setup)}. */
    private static Map<String, Double> params;

    /** Plain holder for the settings of one scenario run. */
    public static class Setup {

        /** Algorithm selector; {@code "BMV"} selects {@link BatchMV}, anything else {@link IncrementalMV}. */
        public String algorithm;

        /** Test name forwarded to {@code setUp}. */
        public String testName;

        /** Flag forwarded to {@code setUp}. */
        public boolean loadEvaluationLabels;

        /**
         * @param alg               algorithm selector
         * @param tName             test name forwarded to {@code setUp}
         * @param lEvaluationLabels flag forwarded to {@code setUp}
         */
        public Setup(String alg, String tName, boolean lEvaluationLabels) {
            algorithm = alg;
            testName = tName;
            loadEvaluationLabels = lEvaluationLabels;
        }
    }

    /**
     * Initializes the scenario with the algorithm named in {@code testSetup}
     * and recomputes every value later checked by the parameterized tests.
     *
     * @param testSetup scenario settings; {@code algorithm} must not be {@code null}
     */
    public static void initSetup(Setup testSetup) {
        // The two branches are kept separate so that each call passes the
        // concrete algorithm type to setUp, exactly as before.
        if (testSetup.algorithm.equals("BMV")) {
            setUp(new BatchMV(), testSetup.testName, testSetup.loadEvaluationLabels);
        } else {
            setUp(new IncrementalMV(), testSetup.testName, testSetup.loadEvaluationLabels);
        }

        params = new HashMap<String, Double>();

        // Misclassification cost: estimated, then evaluated.
        params.put("Estm_MV_Exp", estimateMissclassificationCost(
                LabelProbabilityDistributionCostCalculators.get("EXPECTEDCOST"), null));
        params.put("Estm_MV_ML", estimateMissclassificationCost(
                LabelProbabilityDistributionCostCalculators.get("MAXLIKELIHOOD"), null));
        params.put("Estm_MV_Min", estimateMissclassificationCost(
                LabelProbabilityDistributionCostCalculators.get("MINCOST"), null));
        params.put("Eval_MV_ML", evaluateMissclassificationCost("MAXLIKELIHOOD"));
        params.put("Eval_MV_Min", evaluateMissclassificationCost("MINCOST"));
        params.put("Eval_MV_Soft", evaluateMissclassificationCost("SOFT"));

        // Cost-to-quality figures: estimated, then evaluated.
        params.put("Estm_MV_ML_q", estimateCostToQuality(
                LabelProbabilityDistributionCostCalculators.get("MAXLIKELIHOOD"), null));
        params.put("Estm_MV_Exp_q", estimateCostToQuality(
                LabelProbabilityDistributionCostCalculators.get("EXPECTEDCOST"), null));
        params.put("Estm_MV_Min_q", estimateCostToQuality(
                LabelProbabilityDistributionCostCalculators.get("MINCOST"), null));
        params.put("Eval_MV_ML_q", evaluateCostToQuality("MAXLIKELIHOOD"));
        params.put("Eval_MV_Soft_q", evaluateCostToQuality("SOFT"));
        params.put("Eval_MV_Min_q", evaluateCostToQuality("MINCOST"));
    }

    /**
     * Checks one computed cost value against the matching summary entry.
     *
     * @param p1 key of the computed value
     * @param p2 key of the entry in the parsed summary
     */
    @Test
    @Parameters
    public void testDataCost(String p1, String p2) {
        assertMatchesSummary(p1, p2);
    }

    /** Parameter provider resolved by JUnitParams through its {@code parametersFor<TestName>} convention. */
    private Object[] parametersForTestDataCost() {
        return $(
                $("Estm_MV_Exp", "[DataCost_Estm_MV_Exp] Estimated classification cost (MV_Exp metric)"),
                $("Estm_MV_ML", "[DataCost_Estm_MV_ML] Estimated classification cost (MV_ML metric)"),
                $("Estm_MV_Min", "[DataCost_Estm_MV_Min] Estimated classification cost (MV_Min metric)"),
                $("Eval_MV_ML", "[DataCost_Eval_MV_ML] Actual classification cost for majority vote classification"),
                $("Eval_MV_Min", "[DataCost_Eval_MV_Min] Actual classification cost for naive min-cost classification"),
                $("Eval_MV_Soft", "[DataCost_Eval_MV_Soft] Actual classification cost for naive soft-label classification"));
    }

    /**
     * Checks one computed quality value against the matching summary entry.
     *
     * @param p1 key of the computed value
     * @param p2 key of the entry in the parsed summary
     */
    @Test
    @Parameters
    public void testDataQuality(String p1, String p2) {
        assertMatchesSummary(p1, p2);
    }

    /** Parameter provider resolved by JUnitParams through its {@code parametersFor<TestName>} convention. */
    private Object[] parametersForTestDataQuality() {
        return $(
                $("Estm_MV_ML_q", "[DataQuality_Estm_MV_ML] Estimated data quality, naive majority label"),
                $("Estm_MV_Exp_q", "[DataQuality_Estm_MV_Exp] Estimated data quality, naive soft label"),
                $("Estm_MV_Min_q", "[DataQuality_Estm_MV_Min] Estimated data quality, naive mincost label"),
                $("Eval_MV_ML_q", "[DataQuality_Eval_MV_ML] Actual data quality, naive majority label"),
                $("Eval_MV_Min_q", "[DataQuality_Eval_MV_Min] Actual data quality, naive mincost label"),
                $("Eval_MV_Soft_q", "[DataQuality_Eval_MV_Soft] Actual data quality, naive soft label"));
    }

    /** Verifies that no cost reported by the worker estimator is NaN. */
    @Test
    public void test_WorkerQualityIsNotNan() {
        WorkerEstimator we = new WorkerEstimator(LabelProbabilityDistributionCostCalculators.get("ExpectedCost"));
        for (Double cost : we.getCosts(project).values()) {
            assertFalse("Worker cost must not be NaN", Double.isNaN(cost));
        }
    }

    /**
     * Asserts that the computed value stored under {@code paramKey} is within
     * {@link #TOLERANCE} of the summary entry stored under {@code summaryKey}.
     *
     * <p>Both values are boxed, so a missing entry would otherwise surface as a
     * bare {@code NullPointerException} during unboxing; the explicit null
     * checks turn that into a failure naming the missing key.
     */
    private void assertMatchesSummary(String paramKey, String summaryKey) {
        Map<String, Double> dataQuality = summaryResultsParser.getDataQuality();
        assertNotNull("initSetup(Setup) must be called before running the tests", params);

        Double expected = dataQuality.get(summaryKey);
        Double actual = params.get(paramKey);
        assertNotNull("Summary results contain no entry for \"" + summaryKey + "\"", expected);
        assertNotNull("No computed value registered under \"" + paramKey + "\"", actual);

        assertEquals("Value mismatch for " + paramKey,
                expected.doubleValue(), actual.doubleValue(), TOLERANCE);
    }
}