package com.datascience.core.nominal.decision; import com.datascience.utils.CostMatrix; import com.datascience.utils.ProbabilityDistributions; import com.google.common.base.Strings; import java.util.Map; import java.util.Set; /** * * @author Artur & Konrad */ public class LabelProbabilityDistributionCostCalculators { public static class SelectedLabelBased implements ILabelProbabilityDistributionCostCalculator { private IObjectLabelDecisionAlgorithm labelChooser; public SelectedLabelBased(IObjectLabelDecisionAlgorithm labelChooser){ this.labelChooser = labelChooser; } @Override public Double predictedLabelCost(Map<String, Double> labelProbabilities, CostMatrix<String> costMatrix) { String choosenLabel = labelChooser.predictLabel(labelProbabilities, costMatrix); return ProbabilityDistributions.calculateLabelCost(choosenLabel, labelProbabilities, costMatrix); } } public static class ExpectedCostAlgorithm implements ILabelProbabilityDistributionCostCalculator { @Override public Double predictedLabelCost(Map<String, Double> labelProbabilities, CostMatrix<String> costMatrix) { Set<Map.Entry<String, Double>> entries = labelProbabilities.entrySet(); double cost = 0.; for (Map.Entry<String, Double> entry1: entries) { for (Map.Entry<String, Double> entry2: entries) { double errCost = costMatrix.getCost(entry1.getKey(), entry2.getKey()); cost += entry1.getValue() * entry2.getValue() * errCost; } } return cost; } } public static ILabelProbabilityDistributionCostCalculator get(String method){ if (Strings.isNullOrEmpty(method)) { method = "ExpectedCost"; } method = method.toUpperCase(); if ("EXPECTEDCOST".equals(method)) { return new ExpectedCostAlgorithm(); } try { IObjectLabelDecisionAlgorithm olda = ObjectLabelDecisionAlgorithms.get(method); return new SelectedLabelBased(olda); } catch (IllegalArgumentException ex) { throw new IllegalArgumentException( "Unknown cost calculation method: " + method); } } }