package com.datascience.serialization.json;
import com.datascience.datastoring.datamodels.memory.IncrementalNominalModel;
import com.datascience.datastoring.jobs.Job;
import com.datascience.datastoring.jobs.JobFactory;
import com.datascience.core.base.AssignedLabel;
import com.datascience.core.base.LObject;
import com.datascience.core.base.Worker;
import com.datascience.core.nominal.NominalProject;
import com.datascience.datastoring.datamodels.full.MemoryJobStorage;
import com.datascience.gal.IncrementalDawidSkene;
import com.google.gson.*;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.util.*;
/**
 * Tests how a nominal project created through {@link JobFactory} from JSON
 * settings treats category priors, both when the settings contain a
 * "categoryPriors" entry and when they do not.
 *
 * User: artur
 * Date: 3/27/13
 */
public class NominalModelTest {

    /**
     * Fixed seed for the label generator. An unseeded generator would feed
     * different assignments to the project on every run, which makes a
     * failure impossible to reproduce.
     */
    private static final long RANDOM_SEED = 20130327L;

    Random random = new Random(RANDOM_SEED);
    Gson gson;
    ArrayList<String> categories;

    /** Builds the Gson instance from the project's default builder. */
    public NominalModelTest() {
        GsonBuilder builder = JSONUtils.getFilledDefaultGsonBuilder();
        gson = builder.create();
    }

    /**
     * Converts category names into a JSON array of string primitives.
     *
     * @param categories category names, in the order they should appear
     * @return a JSON array holding one string primitive per category
     */
    private JsonArray createCategoriesJsonArray(List<String> categories) {
        JsonArray cat = new JsonArray();
        for (String s : categories) {
            cat.add(new JsonPrimitive(s));
        }
        return cat;
    }

    /**
     * Builds a JSON array describing uniform priors: one object per category
     * with a "categoryName" and a "value" of {@code 1 / categories.size()}.
     *
     * @param categories category names
     * @return a JSON array with one prior object per category
     */
    private JsonArray createCategoryPriorsJsonArray(List<String> categories) {
        JsonArray ja = new JsonArray();
        // Hoisted: the uniform prior is the same for every category.
        double uniformPrior = 1. / categories.size();
        for (String s : categories) {
            JsonObject cv = new JsonObject();
            cv.addProperty("categoryName", s);
            cv.addProperty("value", uniformPrior);
            ja.add(cv);
        }
        return ja;
    }

    /**
     * Generates one assignment for every (worker, object) pair, each carrying
     * a label drawn from {@code categories} with the seeded generator.
     *
     * @param o number of objects ("object0" .. "object{o-1}")
     * @param w number of workers ("worker0" .. "worker{w-1}")
     * @param categories labels to draw from; must not be empty when both
     *                   {@code o} and {@code w} are positive
     * @return {@code o * w} assignments, grouped by worker
     */
    private Collection<AssignedLabel<String>> createAssigns(int o, int w, List<String> categories) {
        Collection<AssignedLabel<String>> ret = new ArrayList<AssignedLabel<String>>();
        List<LObject<String>> lObjects = new ArrayList<LObject<String>>();
        for (int i = 0; i < o; i++) {
            lObjects.add(new LObject<String>("object" + i));
        }
        for (int i = 0; i < w; i++) {
            Worker worker = new Worker("worker" + i);
            for (LObject<String> lObject : lObjects) {
                String label = categories.get(random.nextInt(categories.size()));
                ret.add(new AssignedLabel<String>(worker, lObject, label));
            }
        }
        return ret;
    }

    /**
     * Creates a nominal job named "test" backed by in-memory storage and
     * returns its project.
     *
     * @param alg value for the "algorithm" setting
     * @param categories category names for the "categories" setting
     * @param priors when true, a uniform "categoryPriors" setting is added
     * @return the project of the newly created job
     */
    private NominalProject createProject(String alg, List<String> categories, boolean priors) {
        JobFactory jf = new JobFactory(new GSONSerializer(), new MemoryJobStorage());
        JsonObject jo = new JsonObject();
        jo.addProperty("algorithm", alg);
        jo.add("categories", createCategoriesJsonArray(categories));
        if (priors) {
            jo.add("categoryPriors", createCategoryPriorsJsonArray(categories));
        }
        Job job = jf.createNominalJob(JSONUtils.tKeys(jo), "test");
        return (NominalProject) job.getProject();
    }

    /** Resets the category list to two categories before each test. */
    @Before
    public void setUp() {
        categories = new ArrayList<String>();
        categories.add("category1");
        categories.add("category2");
    }

    /**
     * Without a "categoryPriors" setting the data must report non-fixed
     * priors and no stored priors, while the incremental model holds a prior
     * entry per category.
     */
    @Test
    public void notFixedPriorsTest() {
        NominalProject project = createProject("IDS", categories, false);
        for (AssignedLabel<String> al : createAssigns(3, 2, categories)) {
            project.getData().addAssign(al);
        }
        Assert.assertFalse(project.getData().arePriorsFixed());
        Assert.assertTrue(project.getAlgorithm().getModel() instanceof IncrementalNominalModel);
        Assert.assertEquals(2, project.getAlgorithm().getModel().getCategoryPriors().size());
        // NOTE(review): 3 matches the number of objects passed to createAssigns;
        // presumably the denominator counts labeled objects - confirm.
        Assert.assertEquals(3, ((IncrementalNominalModel) project.getAlgorithm().getModel()).getPriorDenominator());
        Assert.assertNull(project.getData().getCategoryPriors());
    }

    /**
     * With a uniform "categoryPriors" setting the data must report fixed
     * priors, the incremental model must hold no priors of its own, and the
     * algorithm must report 0.5 for each of the two categories.
     */
    @Test
    public void fixedPriorsTest() {
        NominalProject project = createProject("IDS", categories, true);
        for (AssignedLabel<String> al : createAssigns(3, 2, categories)) {
            project.getData().addAssign(al);
        }
        Assert.assertTrue(project.getData().arePriorsFixed());
        Assert.assertTrue(project.getAlgorithm().getModel() instanceof IncrementalNominalModel);
        Assert.assertEquals(0, project.getAlgorithm().getModel().getCategoryPriors().size());
        Assert.assertEquals(0, ((IncrementalNominalModel) project.getAlgorithm().getModel()).getPriorDenominator());
        for (String s : categories) {
            Assert.assertEquals(0.5, ((IncrementalDawidSkene) project.getAlgorithm()).prior(s), 1e-6);
        }
    }
}