package com.datascience.gal;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import java.util.ArrayList;
import com.datascience.core.base.AssignedLabel;
import com.datascience.core.base.LObject;
import com.datascience.core.nominal.NominalProject;
import com.datascience.core.base.Worker;
import com.datascience.datastoring.datamodels.full.MemoryJobStorage;
import org.junit.Before;
import org.junit.Test;
import com.datascience.core.nominal.decision.DecisionEngine;
import com.datascience.core.nominal.decision.LabelProbabilityDistributionCostCalculators;
public class BatchDawidSkeneTest {
NominalProject project;
@Before
public void setUp(){
ArrayList<String> categories = new ArrayList<String>();
categories.add("category1");
categories.add("category2");
MemoryJobStorage js = new MemoryJobStorage();
project = new NominalProject(new BatchDawidSkene(), js.getNominalData("testid"), js.getNominalResults("testid", categories));
project.initializeCategories(categories, null, null);
}
@Test
public final void testAddLabelWithWrongCategory() {
Worker w = project.getData().getOrCreateWorker("worker");
LObject<String> obj = project.getData().getOrCreateObject("object1");
project.getData().addAssign(new AssignedLabel<String>(w, obj, "category1"));
try {
project.getData().addAssign(new AssignedLabel<String>(w, obj, "wrongLabel"));
fail("Added label with incorrect category.");
} catch(Exception e) {
}
assertEquals(project.getData().getAssigns().size(),1);
}
@Test
public final void testMissclassificationCost() {
LObject<String> obj = new LObject<String>("object1");
project.getData().addObject(obj);
DecisionEngine de = new DecisionEngine(LabelProbabilityDistributionCostCalculators.get(""), null);
try{
de.estimateMissclassificationCost(project, obj);
fail("trying to get estimated value for not computed object");
}
catch(Exception ex){
}
project.getAlgorithm().compute();
assertEquals(
1. / project.getData().getCategories().size(),
de.estimateMissclassificationCost(project, obj),
1e-10);
try{
de.estimateMissclassificationCost(project, new LObject<String>("testObj"));
fail("trying to get estimated value for non existions object");
}
catch(Exception e){
}
}
}