package com.datascience.gal;
import com.datascience.core.algorithms.INewDataObserver;
import com.datascience.core.base.AssignedLabel;
import com.datascience.core.base.LObject;
import com.datascience.core.base.Worker;
import com.datascience.core.nominal.NominalAlgorithm;
import com.datascience.core.nominal.NominalProject;
import com.datascience.core.nominal.decision.DecisionEngine;
import com.datascience.core.nominal.decision.LabelProbabilityDistributionCostCalculators;
import com.datascience.core.nominal.decision.WorkerEstimator;
import com.datascience.datastoring.datamodels.full.MemoryJobStorage;
import org.apache.log4j.Logger;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.assertFalse;
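/**
 * Parameterized edge-case tests: every enabled algorithm creator is combined with
 * degenerate assign sets (no assigns, a single assign, one object with many assigns,
 * many objects with one assign each) and the worker and data cost estimates are
 * checked to never be NaN.
 *
 * User: artur
 * Date: 5/28/13
 */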
@RunWith(Parameterized.class)
public class EdgeCasesTest {
    // NOTE(review): not referenced by any test in this class; final so it cannot be reassigned.
    private static final Logger logger = Logger.getLogger(EdgeCasesTest.class);

    /** Categories used both for the project's results storage and for category initialization. */
    private static final List<String> CATEGORIES = Arrays.asList("category1", "category2", "category3");

    /**
     * Creates a fresh {@link NominalProject} for one algorithm variant.
     * {@link #toString()} supplies the "alg" label of the parameterized test name.
     */
    public static abstract class ProjectCreator {
        abstract NominalProject create();

        /**
         * Wraps {@code algorithm} in a project whose data and results come from a new
         * {@link MemoryJobStorage} under the job id {@code "testid"}.
         *
         * @param algorithm the algorithm under test
         * @return the newly built project
         */
        protected NominalProject getProject(NominalAlgorithm algorithm) {
            MemoryJobStorage storage = new MemoryJobStorage();
            return new NominalProject(
                    algorithm,
                    storage.getNominalData("testid"),
                    storage.getNominalResults("testid", CATEGORIES));
        }
    }

    /**
     * Supplies the assigned labels fed into the project for one test case.
     * {@link #toString()} supplies the "assigns" label of the parameterized test name.
     */
    public interface AssignsCreator {
        AssignedLabel<String>[] create();
    }

    /** Algorithm variants under test. */
    static final ProjectCreator[] PROJECT_CREATORS = new ProjectCreator[]{
            new ProjectCreator() {
                @Override
                public NominalProject create() {
                    return getProject(new BatchDawidSkene());
                }

                @Override
                public String toString() {
                    return "BDS";
                }
            },
            new ProjectCreator() {
                @Override
                public NominalProject create() {
                    IncrementalDawidSkene algorithm = new IncrementalDawidSkene();
                    algorithm.setEpsilon(0.0001);
                    algorithm.setIterations(10);
                    NominalProject project = getProject(algorithm);
                    // Presumably lets the algorithm update itself as assigns are added — confirm.
                    project.getData().addNewUpdatableAlgorithm(algorithm);
                    return project;
                }

                @Override
                public String toString() {
                    return "IDS";
                }
            },
            // NOTE(review): majority-vote variants are disabled; re-enable or delete.
            // new ProjectCreator() {
            //     @Override
            //     public NominalProject create() {
            //         return getProject(new BatchMV());
            //     }
            //     @Override
            //     public String toString(){
            //         return "BMV";
            //     }
            // },
            // new ProjectCreator() {
            //     @Override
            //     public NominalProject create() {
            //         IncrementalMV algorithm = new IncrementalMV();
            //         NominalProject project = getProject(algorithm);
            //         project.getData().addNewUpdatableAlgorithm(algorithm);
            //         return project;
            //     }
            //     @Override
            //     public String toString(){
            //         return "IMV";
            //     }
            // },
    };

    /**
     * Degenerate assign sets. Each {@code create()} builds a raw {@code AssignedLabel[]} because
     * generic arrays cannot be instantiated; the array only ever holds {@code AssignedLabel<String>},
     * so the unchecked conversion is safe and suppressed per method.
     */
    static final AssignsCreator[] ASSIGNS_CREATORS = new AssignsCreator[]{
            new AssignsCreator() {
                @Override
                @SuppressWarnings("unchecked")
                public AssignedLabel<String>[] create() {
                    return new AssignedLabel[]{};
                }

                @Override
                public String toString() {
                    return "NO ASSIGNS";
                }
            },
            new AssignsCreator() {
                @Override
                @SuppressWarnings("unchecked")
                public AssignedLabel<String>[] create() {
                    return new AssignedLabel[]{label("worker", "object", 0)};
                }

                @Override
                public String toString() {
                    return "JUST ONE ASSIGN";
                }
            },
            new AssignsCreator() {
                @Override
                @SuppressWarnings("unchecked")
                public AssignedLabel<String>[] create() {
                    return new AssignedLabel[]{
                            label("worker1", "object", 0),
                            label("worker2", "object", 1),
                            label("worker3", "object", 0),
                            label("worker4", "object", 2)};
                }

                @Override
                public String toString() {
                    return "ONE OBJECT WITH MANY ASSIGNS";
                }
            },
            new AssignsCreator() {
                @Override
                @SuppressWarnings("unchecked")
                public AssignedLabel<String>[] create() {
                    return new AssignedLabel[]{
                            label("worker", "object1", 0),
                            label("worker", "object2", 1),
                            label("worker", "object3", 0),
                            label("worker", "object4", 2)};
                }

                @Override
                public String toString() {
                    return "MANY OBJECTS EACH WITH ONE ASSIGN, ONE WORKER";
                }
            },
            new AssignsCreator() {
                @Override
                @SuppressWarnings("unchecked")
                public AssignedLabel<String>[] create() {
                    return new AssignedLabel[]{
                            label("worker1", "object1", 0),
                            label("worker2", "object2", 1),
                            label("worker3", "object3", 0),
                            label("worker4", "object4", 2)};
                }

                @Override
                public String toString() {
                    return "MANY OBJECTS EACH WITH ONE ASSIGN, MANY WORKERS";
                }
            }
    };

    public ProjectCreator projectCreator;
    public AssignsCreator assignsCreator;
    /** Project built and populated by {@link #initialize()} before each test. */
    public NominalProject project;

    /**
     * @param projectCreator factory for the project/algorithm under test
     * @param assignsCreator factory for the assigns fed into the project
     */
    public EdgeCasesTest(ProjectCreator projectCreator, AssignsCreator assignsCreator) {
        this.projectCreator = projectCreator;
        this.assignsCreator = assignsCreator;
    }

    /**
     * Builds a single label: {@code worker} assigns the category at {@code categoryIndex}
     * of {@link #CATEGORIES} to {@code object}.
     */
    private static AssignedLabel<String> label(String worker, String object, int categoryIndex) {
        return new AssignedLabel<String>(new Worker(worker), new LObject<String>(object),
                CATEGORIES.get(categoryIndex));
    }

    /** Every (project creator, assigns creator) pair, i.e. the cartesian product of both arrays. */
    @Parameterized.Parameters(name = "alg: {0}, assigns: {1}")
    public static Collection<Object[]> instancesToTest() {
        Collection<Object[]> ret = new LinkedList<Object[]>();
        for (ProjectCreator pc : PROJECT_CREATORS) {
            for (AssignsCreator ac : ASSIGNS_CREATORS) {
                ret.add(new Object[]{pc, ac});
            }
        }
        return ret;
    }

    /**
     * Builds the project, initializes its categories, feeds in the assigns, and runs
     * {@code compute()} explicitly for algorithms that are not {@link INewDataObserver}s.
     */
    @Before
    public void initialize() {
        project = projectCreator.create();
        // NOTE(review): the two nulls presumably select default priors / cost matrix — confirm.
        project.initializeCategories(CATEGORIES, null, null);
        for (AssignedLabel<String> al : assignsCreator.create()) {
            project.getData().addAssign(al);
        }
        // Observers are presumably updated as data arrives; only the others need a batch compute.
        if (!(project.getAlgorithm() instanceof INewDataObserver)) {
            project.getAlgorithm().compute();
        }
    }

    /** No worker cost may be NaN; passes vacuously if the estimator returns no entries. */
    @Test
    public void testWorkerCost() {
        WorkerEstimator we = new WorkerEstimator(LabelProbabilityDistributionCostCalculators.get("ExpectedCost"));
        for (Double d : we.getCosts(project).values()) {
            assertFalse("worker cost is NaN", Double.isNaN(d));
        }
    }

    /** No misclassification cost may be NaN; passes vacuously if no estimates are returned. */
    @Test
    public void testDataCost() {
        DecisionEngine de = new DecisionEngine(LabelProbabilityDistributionCostCalculators.get("ExpectedCost"), null);
        for (Double d : de.estimateMissclassificationCosts(project).values()) {
            assertFalse("misclassification cost is NaN", Double.isNaN(d));
        }
    }
}