package com.datascience.gal.commands;

import com.datascience.core.base.LObject;
import com.datascience.core.nominal.NominalProject;
import com.datascience.core.base.Worker;
import com.datascience.core.nominal.Quality;
import com.datascience.core.results.WorkerResult;
import com.datascience.core.stats.ConfusionMatrix;
import com.datascience.core.stats.QSPCalculators;
import com.datascience.core.stats.QualitySensitivePaymentsCalculator;
import com.datascience.datastoring.jobs.JobCommand;
import com.datascience.core.stats.MatrixValue;
import com.datascience.gal.*;
import com.datascience.core.nominal.decision.*;
import com.datascience.utils.MathHelpers;

import java.util.*;
import java.util.Map.Entry;

import static com.datascience.core.nominal.Quality.getMinSpammerCost;

/**
 * Job commands that expose prediction results of a {@link NominalProject}:
 * worker confusion matrices, worker payments, worker quality/cost, predicted
 * categories and per-object cost/quality, plus aggregated summaries.
 *
 * <p>NOTE: average quality is computed as Quality(average cost), i.e. costs
 * are averaged first and the average is then converted to a quality value.
 *
 * @author artur
 */
public class PredictionCommands {

    /**
     * Names of the cost calculators reported by the summary commands, in the
     * order in which they are evaluated. Each name is resolved through
     * {@code LabelProbabilityDistributionCostCalculators.get(String)}.
     */
    private static final String[] SUMMARY_COST_METHODS = {"ExpectedCost", "MinCost", "MaxLikelihood"};

    /**
     * Flattens a confusion matrix into a collection of cells.
     *
     * @param confusionMatrix matrix to flatten
     * @return one {@link MatrixValue} per ordered pair of categories, holding
     *         the normalized error rate reported by the matrix for that pair
     */
    public static Collection<MatrixValue<String>> getWorkerConfusionMatrix(ConfusionMatrix confusionMatrix) {
        Collection<MatrixValue<String>> matrix = new ArrayList<MatrixValue<String>>();
        for (String from : confusionMatrix.getCategories()) {
            for (String to : confusionMatrix.getCategories()) {
                matrix.add(new MatrixValue<String>(from, to, confusionMatrix.getNormalizedErrorRate(from, to)));
            }
        }
        return matrix;
    }

    /** Returns the flattened confusion matrix of every worker in the project. */
    public static class GetWorkersConfusionMatrix
            extends JobCommand<Collection<WorkerValue<Collection<MatrixValue<String>>>>, NominalProject> {

        public GetWorkersConfusionMatrix() {
            super(false);
        }

        @Override
        protected void realExecute() {
            Collection<WorkerValue<Collection<MatrixValue<String>>>> matrices =
                    new ArrayList<WorkerValue<Collection<MatrixValue<String>>>>();
            Map<Worker, WorkerResult> workerResults =
                    project.getResults().getWorkerResults(project.getData().getWorkers());
            for (Entry<Worker, WorkerResult> e : workerResults.entrySet()) {
                matrices.add(new WorkerValue<Collection<MatrixValue<String>>>(
                        e.getKey().getName(),
                        getWorkerConfusionMatrix(e.getValue().getConfusionMatrix())));
            }
            setResult(matrices);
        }
    }

    /** Returns the flattened confusion matrix of a single worker, looked up by name. */
    public static class GetWorkerConfusionMatrix
            extends JobCommand<WorkerValue<Collection<MatrixValue<String>>>, NominalProject> {

        String wid;

        /**
         * @param wid name of the worker whose confusion matrix is requested
         */
        public GetWorkerConfusionMatrix(String wid) {
            super(false);
            this.wid = wid;
        }

        @Override
        protected void realExecute() {
            Worker worker = project.getData().getWorker(wid);
            ConfusionMatrix confusionMatrix = project.getResults().getWorkerResult(worker).getConfusionMatrix();
            setResult(new WorkerValue<Collection<MatrixValue<String>>>(
                    wid, getWorkerConfusionMatrix(confusionMatrix)));
        }
    }

    /**
     * Returns the wage of every worker in the project, computed with
     * {@link QSPCalculators.Linear}.
     */
    public static class GetWorkersPayments extends JobCommand<Collection<WorkerValue<Double>>, NominalProject> {

        double qualifiedWage;
        double costThreshold;

        /**
         * @param qualifiedWage wage argument forwarded to the payments calculator
         * @param costThreshold cost threshold argument forwarded to the payments calculator
         */
        public GetWorkersPayments(double qualifiedWage, double costThreshold) {
            super(false);
            this.qualifiedWage = qualifiedWage;
            this.costThreshold = costThreshold;
        }

        @Override
        protected void realExecute() {
            Collection<WorkerValue<Double>> wages = new LinkedList<WorkerValue<Double>>();
            for (Worker w : project.getData().getWorkers()) {
                QualitySensitivePaymentsCalculator calculator = new QSPCalculators.Linear(project, w);
                wages.add(new WorkerValue<Double>(
                        w.getName(), calculator.getWorkerWage(qualifiedWage, costThreshold)));
            }
            setResult(wages);
        }
    }

    /**
     * Returns the wage of a single worker, computed with
     * {@link QSPCalculators.Linear}.
     */
    public static class GetWorkerPayment extends JobCommand<WorkerValue<Double>, NominalProject> {

        double qualifiedWage;
        double costThreshold;
        String wid;

        /**
         * @param wid name of the worker whose wage is requested
         * @param qualifiedWage wage argument forwarded to the payments calculator
         * @param costThreshold cost threshold argument forwarded to the payments calculator
         */
        public GetWorkerPayment(String wid, double qualifiedWage, double costThreshold) {
            super(false);
            this.wid = wid;
            this.qualifiedWage = qualifiedWage;
            this.costThreshold = costThreshold;
        }

        @Override
        protected void realExecute() {
            QualitySensitivePaymentsCalculator calculator =
                    new QSPCalculators.Linear(project, project.getData().getWorker(wid));
            setResult(new WorkerValue<Double>(wid, calculator.getWorkerWage(qualifiedWage, costThreshold)));
        }
    }

    /** Returns the quality of every worker as reported by the supplied calculator. */
    public static class GetWorkersQuality extends JobCommand<Collection<WorkerValue<Double>>, NominalProject> {

        private WorkerQualityCalculator wqc;

        /**
         * @param wqc calculator used to obtain each worker's quality
         */
        public GetWorkersQuality(WorkerQualityCalculator wqc) {
            super(false);
            this.wqc = wqc;
        }

        @Override
        protected void realExecute() {
            Collection<WorkerValue<Double>> qualities = new LinkedList<WorkerValue<Double>>();
            for (Worker w : project.getData().getWorkers()) {
                qualities.add(new WorkerValue<Double>(w.getName(), wqc.getQuality(project, w)));
            }
            setResult(qualities);
        }
    }

    /** Returns the cost of every worker as reported by the supplied calculator. */
    public static class GetWorkersCost extends JobCommand<Collection<WorkerValue<Double>>, NominalProject> {

        private WorkerQualityCalculator wqc;

        /**
         * @param wqc calculator used to obtain each worker's cost
         */
        public GetWorkersCost(WorkerQualityCalculator wqc) {
            super(false);
            this.wqc = wqc;
        }

        @Override
        protected void realExecute() {
            Collection<WorkerValue<Double>> costs = new LinkedList<WorkerValue<Double>>();
            for (Worker w : project.getData().getWorkers()) {
                costs.add(new WorkerValue<Double>(w.getName(), wqc.getCost(project, w)));
            }
            setResult(costs);
        }
    }

    /**
     * Returns, for each summary cost method, the quality derived from the
     * average (NaN values skipped) of all worker costs.
     */
    public static class GetWorkersQualitySummary extends JobCommand<Map<String, Object>, NominalProject> {

        public GetWorkersQualitySummary() {
            super(false);
        }

        @Override
        protected void realExecute() throws Exception {
            HashMap<String, Object> summary = new HashMap<String, Object>();
            for (String method : SUMMARY_COST_METHODS) {
                WorkerQualityCalculator wqc =
                        new WorkerEstimator(LabelProbabilityDistributionCostCalculators.get(method));
                // Average the costs first, then convert to quality (see class note).
                double averageCost = MathHelpers.getAverageNotNaN(wqc.getCosts(project));
                summary.put(method, Quality.fromCost(project, averageCost));
            }
            setResult(summary);
        }
    }

    /** Returns the predicted label of every object, using the supplied decision algorithm. */
    public static class GetPredictedCategory extends JobCommand<Collection<DatumClassification>, NominalProject> {

        private DecisionEngine decisionEngine;

        /**
         * @param lda decision algorithm used to pick a label for each object
         */
        public GetPredictedCategory(IObjectLabelDecisionAlgorithm lda) {
            super(false);
            // Only label prediction is used here, so no cost calculator is supplied.
            decisionEngine = new DecisionEngine(null, lda);
        }

        @Override
        protected void realExecute() {
            Collection<DatumClassification> classifications = new ArrayList<DatumClassification>();
            for (Entry<LObject<String>, String> e : decisionEngine.predictLabels(project).entrySet()) {
                classifications.add(new DatumClassification(e.getKey().getName(), e.getValue()));
            }
            setResult(classifications);
        }
    }

    /** Returns the estimated misclassification cost of every object. */
    public static class GetDataCost extends JobCommand<Collection<DatumValue>, NominalProject> {

        private DecisionEngine decisionEngine;

        /**
         * @param lca cost calculator used to estimate each object's misclassification cost
         */
        public GetDataCost(ILabelProbabilityDistributionCostCalculator lca) {
            super(false);
            // Only cost estimation is used here, so no decision algorithm is supplied.
            decisionEngine = new DecisionEngine(lca, null);
        }

        @Override
        protected void realExecute() {
            Collection<DatumValue> costs = new ArrayList<DatumValue>();
            for (Entry<LObject<String>, Double> e
                    : decisionEngine.estimateMissclassificationCosts(project).entrySet()) {
                costs.add(new DatumValue(e.getKey().getName(), e.getValue()));
            }
            setResult(costs);
        }
    }

    /**
     * Returns the quality of every object, obtained by converting its estimated
     * misclassification cost with {@code Quality.fromCosts}.
     */
    public static class GetDataQuality extends JobCommand<Collection<DatumValue>, NominalProject> {

        private DecisionEngine decisionEngine;

        /**
         * @param lca cost calculator used to estimate each object's misclassification cost
         */
        public GetDataQuality(ILabelProbabilityDistributionCostCalculator lca) {
            super(false);
            // Only cost estimation is used here, so no decision algorithm is supplied.
            decisionEngine = new DecisionEngine(lca, null);
        }

        @Override
        protected void realExecute() {
            Collection<DatumValue> qualities = new ArrayList<DatumValue>();
            Map<LObject<String>, Double> costs = decisionEngine.estimateMissclassificationCosts(project);
            for (Entry<LObject<String>, Double> e : Quality.fromCosts(project, costs).entrySet()) {
                qualities.add(new DatumValue(e.getKey().getName(), e.getValue()));
            }
            setResult(qualities);
        }
    }

    /**
     * Returns, for each summary cost method, the average estimated
     * misclassification cost over all objects, plus the minimal spammer cost
     * under the key {@code "Spammer"}.
     */
    public static class GetDataCostSummary extends JobCommand<Map<String, Object>, NominalProject> {

        public GetDataCostSummary() {
            super(false);
        }

        @Override
        protected void realExecute() throws Exception {
            HashMap<String, Object> summary = new HashMap<String, Object>();
            for (String method : SUMMARY_COST_METHODS) {
                ILabelProbabilityDistributionCostCalculator lpdcc =
                        LabelProbabilityDistributionCostCalculators.get(method);
                DecisionEngine de = new DecisionEngine(lpdcc, null);
                summary.put(method, MathHelpers.getAverage(de.estimateMissclassificationCosts(project)));
            }
            summary.put("Spammer", getMinSpammerCost(project));
            setResult(summary);
        }
    }

    /**
     * Returns, for each summary cost method, the quality derived from the
     * average estimated misclassification cost over all objects.
     */
    public static class GetDataQualitySummary extends JobCommand<Map<String, Object>, NominalProject> {

        public GetDataQualitySummary() {
            super(false);
        }

        @Override
        protected void realExecute() throws Exception {
            HashMap<String, Object> summary = new HashMap<String, Object>();
            for (String method : SUMMARY_COST_METHODS) {
                ILabelProbabilityDistributionCostCalculator lpdcc =
                        LabelProbabilityDistributionCostCalculators.get(method);
                DecisionEngine de = new DecisionEngine(lpdcc, null);
                // Average the costs first, then convert to quality (see class note).
                double averageCost = MathHelpers.getAverage(de.estimateMissclassificationCosts(project));
                summary.put(method, Quality.fromCost(project, averageCost));
            }
            setResult(summary);
        }
    }

    /**
     * Prediction archive for nominal projects. Registers two statistics files:
     * {@code prediction.tsv} (predicted label per object) and
     * {@code workers_quality.tsv} (quality per worker).
     */
    public static class GetPredictionZip
            extends com.datascience.core.commands.PredictionCommands.AbstractGetPredictionZip<NominalProject> {

        /**
         * @param path forwarded unchanged to the superclass constructor
         */
        public GetPredictionZip(String path) {
            super(path);
            HashMap<String, GetStatistics> files = new HashMap<String, GetStatistics>();
            files.put("prediction.tsv", new GetDataPrediction());
            // Refers to the inner GetWorkersQuality declared below, not the
            // PredictionCommands.GetWorkersQuality command.
            files.put("workers_quality.tsv", new GetWorkersQuality());
            setStatisticsFilesMap(files);
        }

        /**
         * Builds a table with a header row followed by one row per object:
         * the object name and its predicted label for each listed algorithm.
         */
        class GetDataPrediction extends GetStatistics {

            @Override
            public List<List<Object>> call() {
                String[] lda = new String[]{"MinCost"}; // TROIA-393 {"MaxLikelihood", "MinCost"}
                List<List<Object>> ret = new ArrayList<List<Object>>();

                List<Object> header = new ArrayList<Object>();
                header.add("");
                for (String la : lda) {
                    header.add(la);
                }
                ret.add(header);

                for (LObject<String> d : project.getData().getObjects()) {
                    List<Object> line = new ArrayList<Object>();
                    line.add(d.getName());
                    for (String la : lda) {
                        // NOTE(review): the engine is rebuilt for every object; it could
                        // probably be created once per algorithm if DecisionEngine holds
                        // no per-call state - confirm before hoisting.
                        DecisionEngine decisionEngine = new DecisionEngine(
                                null, ObjectLabelDecisionAlgorithms.get(la));
                        line.add(decisionEngine.predictLabel(project, d));
                    }
                    ret.add(line);
                }
                return ret;
            }
        }

        /**
         * Builds a table with a header row followed by one row per worker:
         * the worker name and, for each listed cost method, the quality
         * derived from the worker's cost.
         */
        class GetWorkersQuality extends GetStatistics {

            @Override
            public List<List<Object>> call() {
                String[] lca = new String[]{"MinCost", "ExpectedCost"}; // TROIA-393 {"MaxLikelihood", "MinCost", "ExpectedCost"}
                List<List<Object>> ret = new ArrayList<List<Object>>();

                List<Object> header = new ArrayList<Object>();
                header.add("");
                for (String lc : lca) {
                    header.add(lc);
                }
                ret.add(header);

                for (Worker w : project.getData().getWorkers()) {
                    List<Object> line = new ArrayList<Object>();
                    line.add(w.getName());
                    for (String lc : lca) {
                        WorkerQualityCalculator wqc = new WorkerEstimator(
                                LabelProbabilityDistributionCostCalculators.get(lc));
                        line.add(Quality.fromCost(project, wqc.getCost(project, w)));
                    }
                    ret.add(line);
                }
                return ret;
            }
        }
    }
}