package com.datascience.core.results; import com.datascience.core.base.AssignedLabel; import com.datascience.core.base.CategoryPair; import com.datascience.core.base.LObject; import com.datascience.core.stats.*; import com.google.common.base.Objects; import java.util.Collection; import java.util.HashMap; import java.util.Map; /** * User: artur */ public class WorkerResult { // The error matrix for the worker public ConfusionMatrix cm; //The confusion matrix for the worker based on evaluation data protected ConfusionMatrix eval_cm; public WorkerResult(Collection<String> categories){ cm = new MultinomialConfusionMatrix(categories); } public double getErrorRate(IErrorRateCalculator erc, String categoryFrom, String categoryTo){ return erc.getErrorRate(cm, categoryFrom, categoryTo); } public double getEvalErrorRate(String from, String to){ return eval_cm.getErrorRateBatch(from, to); } public void empty() { cm.empty(); } public ConfusionMatrix getConfusionMatrix(){ return cm; } public Map<String, Double> getPrior(Collection<AssignedLabel<String>> workerAssigns, Collection<String> categories){ int sum = workerAssigns.size(); HashMap<String, Double> worker_prior = new HashMap<String, Double>(); for (String category : categories) { if (sum>0) { double cnt = 0; for (AssignedLabel<String> al : workerAssigns) if (al.getLabel().equals(category)) cnt += 1.; Double prob = cnt / sum; worker_prior.put(category, prob); } else { worker_prior.put(category, 1.0/categories.size()); } } return worker_prior; } public void addError(String source, String destination, double error) { cm.addError(source, destination, error); } public void removeError(String source, String destination, double error) { cm.removeError(source, destination, error); } public void normalize(ConfusionMatrixNormalizationType type) { switch (type) { case UNIFORM: cm.normalize(); break; case LAPLACE: cm.normalizeLaplacean(); break; } } public void computeEvalConfusionMatrix(Collection<String> categories, Collection<AssignedLabel<String>> workerAssigns, Collection<LObject<String>> evaluationObjects) { eval_cm = new MultinomialConfusionMatrix(categories, new HashMap<CategoryPair, Double>()); for (AssignedLabel<String> l : workerAssigns) { for (LObject<String> eo : evaluationObjects){ if (eo.equals(l.getLobject())){ String assignedCategory = l.getLabel(); String correctCategory = eo.getEvaluationLabel(); eval_cm.addError(correctCategory, assignedCategory, 1.0); break; } } } eval_cm.normalize(); } @Override public boolean equals(Object other){ if (other instanceof WorkerResult) { return Objects.equal(cm, ((WorkerResult) other).cm); } return false; } @Override public int hashCode(){ return Objects.hashCode(cm); } }