package com.datascience.core.nominal.decision;

import com.datascience.utils.CostMatrix;
import com.datascience.utils.ProbabilityDistributions;
import com.google.common.base.Strings;

import java.util.Locale;
import java.util.Map;

/**
 * Holder for the available {@link IObjectLabelDecisionAlgorithm} implementations
 * and a name-based factory ({@link #get(String)}) for obtaining them.
 *
 * @author konrad
 */
public class ObjectLabelDecisionAlgorithms {

	/**
	 * Picks the label whose probability is the greatest.
	 */
	public static class MaxProbabilityDecisionAlgorithm implements IObjectLabelDecisionAlgorithm {

		/**
		 * Returns the key of {@code labelProbabilities} with the strictly greatest value.
		 *
		 * @param labelProbabilities map from label to its probability
		 * @param costMatrix not used by this algorithm
		 * @return the label with the highest probability, or {@code null} when the
		 *         map is empty or no value compares greater than
		 *         {@link Double#NEGATIVE_INFINITY} (e.g. every value is NaN)
		 */
		@Override
		public String predictLabel(Map<String, Double> labelProbabilities,
				CostMatrix<String> costMatrix) {
			String mostProbableLabel = null;
			double mostProbableLabelProb = Double.NEGATIVE_INFINITY;
			for (Map.Entry<String, Double> entry : labelProbabilities.entrySet()) {
				double probability = entry.getValue();
				// Strict comparison: on ties the label met first in iteration order is kept.
				if (probability > mostProbableLabelProb) {
					mostProbableLabel = entry.getKey();
					mostProbableLabelProb = probability;
				}
			}
			return mostProbableLabel;
		}
	}

	/**
	 * Picks the label whose cost, as computed by
	 * {@link ProbabilityDistributions#calculateLabelCost}, is the lowest.
	 */
	public static class MinCostDecisionAlgorithm implements IObjectLabelDecisionAlgorithm {

		/**
		 * Evaluates the cost of every key of {@code labelProbabilities} and returns
		 * the one with the strictly lowest cost.
		 *
		 * @param labelProbabilities map from label to its probability
		 * @param costMatrix cost matrix handed to the cost calculation
		 * @return the label with the lowest cost, or {@code null} when the map is
		 *         empty or no cost compares lower than
		 *         {@link Double#POSITIVE_INFINITY}
		 */
		@Override
		public String predictLabel(Map<String, Double> labelProbabilities,
				CostMatrix<String> costMatrix) {
			String minCostLabel = null;
			double minCostLabelCost = Double.POSITIVE_INFINITY;
			for (String label : labelProbabilities.keySet()) {
				double cost = ProbabilityDistributions.calculateLabelCost(
						label, labelProbabilities, costMatrix);
				// Strict comparison: on ties the label met first in iteration order is kept.
				if (cost < minCostLabelCost) {
					minCostLabel = label;
					minCostLabelCost = cost;
				}
			}
			return minCostLabel;
		}
	}

	/**
	 * Creates a decision algorithm from its name. The name is matched
	 * case-insensitively; a {@code null} or empty name selects "MaxLikelihood".
	 *
	 * @param algorithmName "MinCost" or "MaxLikelihood" (any letter case),
	 *        or {@code null}/empty for the default
	 * @return a new {@link MinCostDecisionAlgorithm} for "MinCost", a new
	 *         {@link MaxProbabilityDecisionAlgorithm} for "MaxLikelihood"
	 * @throws IllegalArgumentException if the name matches neither algorithm
	 */
	public static IObjectLabelDecisionAlgorithm get(String algorithmName) {
		if (Strings.isNullOrEmpty(algorithmName)) {
			algorithmName = "MaxLikelihood";
		}
		// Locale.ROOT keeps the mapping independent of the JVM default locale;
		// e.g. under a Turkish locale a plain toUpperCase() maps 'i' to a dotted
		// capital I, so "MinCost" would no longer equal "MINCOST".
		algorithmName = algorithmName.toUpperCase(Locale.ROOT);
		if ("MINCOST".equals(algorithmName)) {
			return new MinCostDecisionAlgorithm();
		}
		if ("MAXLIKELIHOOD".equals(algorithmName)) {
			return new MaxProbabilityDecisionAlgorithm();
		}
		throw new IllegalArgumentException("Unknown decision algorithm: " + algorithmName);
	}
}