package com.datascience.core.nominal.decision;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import com.datascience.core.base.LObject;
import com.datascience.core.nominal.NominalProject;
import com.datascience.utils.CostMatrix;

/**
 * Pairs a label-decision strategy with a cost-estimation strategy and applies them to the
 * objects of a {@link NominalProject}.
 *
 * <p>For any object, the category probability distribution is taken from the project's results
 * (see {@link #getPD(LObject, NominalProject)}). That distribution plus a cost matrix is then
 * handed either to an {@link IObjectLabelDecisionAlgorithm}, to choose a label, or to an
 * {@link ILabelProbabilityDistributionCostCalculator}, to estimate the cost tied to that
 * choice. Overloads without an explicit cost matrix use
 * {@code project.getData().getCostMatrix()}.
 *
 * @author konrad
 */
public class DecisionEngine {

    /** Strategy consulted by every {@code estimateMissclassificationCost*} method. */
    ILabelProbabilityDistributionCostCalculator labelProbabilityDistributionCostCalculator;

    /** Strategy consulted by every {@code predictLabel*} method. */
    IObjectLabelDecisionAlgorithm objectLabelDecisionAlgorithm;

    /**
     * Creates an engine from its two strategies. Neither argument is validated here.
     *
     * @param labelProbabilityDistributionCostCalculator computes the cost figure returned by the
     *        {@code estimateMissclassificationCost*} methods
     * @param objectLabelDecisionAlgorithm picks the label returned by the {@code predictLabel*}
     *        methods
     */
    public DecisionEngine(ILabelProbabilityDistributionCostCalculator labelProbabilityDistributionCostCalculator,
            IObjectLabelDecisionAlgorithm objectLabelDecisionAlgorithm) {
        this.objectLabelDecisionAlgorithm = objectLabelDecisionAlgorithm;
        this.labelProbabilityDistributionCostCalculator = labelProbabilityDistributionCostCalculator;
    }

    /**
     * Looks up the probability distribution the project currently holds for one object.
     *
     * @param datum the object whose results are queried
     * @param project the project that owns the results
     * @return the map produced by {@code project.getObjectResults(datum).getCategoryProbabilites()}
     *         (presumably keyed by category name — confirm against the results class)
     */
    public Map<String, Double> getPD(LObject<String> datum, NominalProject project) {
        // "getCategoryProbabilites" (sic) is the spelling used by the results API.
        return project.getObjectResults(datum).getCategoryProbabilites();
    }

    // ---------------------------------------------------------------- label prediction

    /**
     * Chooses a label for one object using an explicit cost matrix.
     *
     * @param project the project supplying the object's probability distribution
     * @param datum the object to label
     * @param cm the cost matrix passed to the decision algorithm
     * @return whatever label the decision algorithm selects
     */
    public String predictLabel(NominalProject project, LObject<String> datum, CostMatrix<String> cm) {
        Map<String, Double> distribution = getPD(datum, project);
        return objectLabelDecisionAlgorithm.predictLabel(distribution, cm);
    }

    /**
     * Chooses a label for one object using the project's own cost matrix.
     *
     * @param project the project supplying both the distribution and the cost matrix
     * @param datum the object to label
     * @return whatever label the decision algorithm selects
     */
    public String predictLabel(NominalProject project, LObject<String> datum) {
        CostMatrix<String> projectCosts = project.getData().getCostMatrix();
        return predictLabel(project, datum, projectCosts);
    }

    /**
     * Chooses a label for every object in {@code project.getData().getObjects()}.
     *
     * <p>The project's cost matrix is fetched once and reused for all objects.
     *
     * @param project the project whose objects are labelled
     * @return a new {@link HashMap} from each object to its predicted label
     */
    public Map<LObject<String>, String> predictLabels(NominalProject project) {
        Collection<LObject<String>> objects = project.getData().getObjects();
        CostMatrix<String> projectCosts = project.getData().getCostMatrix();
        Map<LObject<String>, String> labels = new HashMap<LObject<String>, String>();
        for (LObject<String> object : objects) {
            String label = predictLabel(project, object, projectCosts);
            labels.put(object, label);
        }
        return labels;
    }

    // ---------------------------------------------------------------- cost estimation

    /**
     * Estimates the cost of the decision for one object using an explicit cost matrix.
     *
     * @param project the project supplying the object's probability distribution
     * @param datum the object whose cost is estimated
     * @param cm the cost matrix passed to the cost calculator
     * @return the value reported by the cost calculator's {@code predictedLabelCost}
     */
    public double estimateMissclassificationCost(NominalProject project, LObject<String> datum, CostMatrix<String> cm) {
        Map<String, Double> distribution = getPD(datum, project);
        return labelProbabilityDistributionCostCalculator.predictedLabelCost(distribution, cm);
    }

    /**
     * Estimates the cost for one object using the project's own cost matrix.
     *
     * @param project the project supplying both the distribution and the cost matrix
     * @param datum the object whose cost is estimated
     * @return the value reported by the cost calculator's {@code predictedLabelCost}
     */
    public double estimateMissclassificationCost(NominalProject project, LObject<String> datum) {
        CostMatrix<String> projectCosts = project.getData().getCostMatrix();
        return estimateMissclassificationCost(project, datum, projectCosts);
    }

    /**
     * Estimates the cost for a caller-supplied distribution, priced with the project's cost
     * matrix. No project object results are consulted.
     *
     * @param project the project supplying the cost matrix
     * @param pd the probability distribution to evaluate
     * @return the value reported by the cost calculator's {@code predictedLabelCost}
     */
    public double estimateMissclassificationCost(NominalProject project, Map<String, Double> pd) {
        CostMatrix<String> projectCosts = project.getData().getCostMatrix();
        return labelProbabilityDistributionCostCalculator.predictedLabelCost(pd, projectCosts);
    }

    /**
     * Estimates the cost for every object in {@code project.getData().getObjects()}.
     *
     * <p>The project's cost matrix is fetched once and reused for all objects.
     *
     * @param project the project whose objects are evaluated
     * @return a new {@link HashMap} from each object to its estimated cost
     */
    public Map<LObject<String>, Double> estimateMissclassificationCosts(NominalProject project) {
        Collection<LObject<String>> objects = project.getData().getObjects();
        CostMatrix<String> projectCosts = project.getData().getCostMatrix();
        Map<LObject<String>, Double> costs = new HashMap<LObject<String>, Double>();
        for (LObject<String> object : objects) {
            double cost = estimateMissclassificationCost(project, object, projectCosts);
            costs.put(object, cost);
        }
        return costs;
    }
}