package quickml.supervised.crossValidation.attributeImportance;

import com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.supervised.crossValidation.PredictionMapResults;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Tracks weighted running losses for a set of classifier loss functions across successive
 * batches of {@link PredictionMapResults}, with one loss function designated as primary.
 *
 * <p>Each batch contributes {@code loss * batchWeight} to a running sum. The reported loss for a
 * function is that sum divided by the total weight seen so far.
 */
public class LossFunctionTracker {
    private static final Logger logger = LoggerFactory.getLogger(LossFunctionTracker.class);

    // Loss function name -> running loss for that function. Insertion-ordered so that
    // lossFunctionNames() and logLosses() iterate in a stable, predictable order.
    private final Map<String, RunningWeight> functionLossMap = Maps.newLinkedHashMap();
    private final ClassifierLossFunction primaryLossFunction;

    /**
     * Creates a tracker whose primary loss function is the first element of {@code lossFunctions}.
     * The remaining elements are tracked as secondary loss functions.
     *
     * @param lossFunctions loss functions to track; the first one is the primary
     * @throws IllegalArgumentException if {@code lossFunctions} is empty
     */
    public LossFunctionTracker(List<ClassifierLossFunction> lossFunctions) {
        // Arguments are evaluated left to right, so getSecondaryLossFunctions validates that the
        // list is non-empty before get(0) runs.
        this(getSecondaryLossFunctions(lossFunctions), lossFunctions.get(0));
    }

    /**
     * Creates a tracker for the given secondary loss functions plus an explicit primary one.
     * Functions are keyed by {@link ClassifierLossFunction#getName()}. If a secondary function
     * shares a name with the primary, the primary's tracker replaces it.
     *
     * @param lossFunctions       secondary loss functions to track
     * @param primaryLossFunction loss function reported by {@link #getPrimaryLoss()}
     */
    public LossFunctionTracker(List<ClassifierLossFunction> lossFunctions, ClassifierLossFunction primaryLossFunction) {
        this.primaryLossFunction = primaryLossFunction;
        for (ClassifierLossFunction lossFunction : lossFunctions) {
            functionLossMap.put(lossFunction.getName(), new RunningWeight(lossFunction));
        }
        // Inserted last so that it wins over any secondary entry with the same name.
        functionLossMap.put(primaryLossFunction.getName(), new RunningWeight(primaryLossFunction));
    }

    /**
     * Adds one batch of prediction results to the running loss of every tracked function.
     *
     * @param results the batch of predictions to score
     */
    public void updateLosses(PredictionMapResults results) {
        for (RunningWeight runningWeight : functionLossMap.values()) {
            runningWeight.updateLosses(results);
        }
    }

    /**
     * Returns the names of all tracked loss functions, in insertion order.
     *
     * @return an unmodifiable live view of the tracked loss function names
     */
    public Set<String> lossFunctionNames() {
        return Collections.unmodifiableSet(functionLossMap.keySet());
    }

    /**
     * Returns the current weighted-average loss of the primary loss function.
     *
     * @return the primary loss, or 0 if no weight has been accumulated yet
     */
    public double getPrimaryLoss() {
        return getLossForFunction(primaryLossFunction.getName());
    }

    /**
     * Returns the current weighted-average loss of the named loss function.
     *
     * @param lossFunction name of a tracked loss function
     * @return the loss, or 0 if no weight has been accumulated yet
     * @throws IllegalArgumentException if no loss function with that name is tracked
     */
    public double getLossForFunction(String lossFunction) {
        RunningWeight runningWeight = functionLossMap.get(lossFunction);
        checkArgument(runningWeight != null, "No loss function named '%s' is being tracked", lossFunction);
        return runningWeight.loss();
    }

    /** Logs the current loss of every tracked loss function at INFO level. */
    public void logLosses() {
        for (Map.Entry<String, RunningWeight> entry : functionLossMap.entrySet()) {
            logger.info("Loss function - {} - Loss - {}", entry.getKey(), entry.getValue().loss());
        }
    }

    /**
     * Returns all loss functions except the first, which is treated as the primary one.
     *
     * @throws IllegalArgumentException if {@code lossFunctions} is empty
     */
    private static List<ClassifierLossFunction> getSecondaryLossFunctions(List<ClassifierLossFunction> lossFunctions) {
        checkArgument(lossFunctions.size() > 0, "There must be at least one loss function supplied");
        return lossFunctions.subList(1, lossFunctions.size());
    }

    /**
     * Weighted running loss for a single loss function. Each batch's loss is multiplied by the
     * batch's total weight before being summed, so {@link #loss()} is a weighted average.
     */
    static class RunningWeight implements Comparable<RunningWeight> {
        double runningLoss = 0;
        double runningWeightOfValidationSet = 0;
        private final ClassifierLossFunction lossFunction;

        public RunningWeight(ClassifierLossFunction lossFunction) {
            this.lossFunction = lossFunction;
        }

        /** Adds this batch's weighted loss and its weight to the running totals. */
        public void updateLosses(PredictionMapResults results) {
            double batchLoss = lossFunction.getLoss(results);
            double batchWeight = results.totalWeight();
            runningLoss += batchLoss * batchWeight;
            runningWeightOfValidationSet += batchWeight;
        }

        /** Returns the weighted-average loss so far, or 0 when no weight has been accumulated. */
        public double loss() {
            return runningWeightOfValidationSet > 0 ? runningLoss / runningWeightOfValidationSet : 0;
        }

        /** Orders trackers by their current loss, ascending. */
        @Override
        public int compareTo(RunningWeight o) {
            return Double.compare(loss(), o.loss());
        }
    }
}