package quickml.supervised.crossValidation.attributeImportance; import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.supervised.crossValidation.PredictionMapResults; import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight; import quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions.RegressionLossFunction; import java.util.List; import java.util.Map; import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; public class RegLossFunctionTracker { private static final Logger logger = LoggerFactory.getLogger(RegLossFunctionTracker.class); // Map of loss function name to the running loss for that function private Map<String, RunningWeight> functionLossMap = Maps.newHashMap(); private RegressionLossFunction primaryLossFunction; public RegLossFunctionTracker(List<RegressionLossFunction> lossFunctions) { this(getSecondaryLossFunctions(lossFunctions), lossFunctions.get(0)); } public RegLossFunctionTracker(List<RegressionLossFunction> lossFunctions, RegressionLossFunction primaryLossFunction) { this.primaryLossFunction = primaryLossFunction; for (RegressionLossFunction lossFunction : lossFunctions) { functionLossMap.put(lossFunction.getName(), new RunningWeight(lossFunction)); } functionLossMap.put(primaryLossFunction.getName(), new RunningWeight(primaryLossFunction)); } public void updateLosses(List<LabelPredictionWeight<Double, Double>> results) { for (RunningWeight runningWeight : functionLossMap.values()) { runningWeight.updateLosses(results); } } public Set<String> lossFunctionNames() { return functionLossMap.keySet(); } public double getPrimaryLoss() { return getLossForFunction(primaryLossFunction.getName()); } public double getLossForFunction(String lossFunction) { return functionLossMap.get(lossFunction).loss(); } public void logLosses() { for (String functionName : functionLossMap.keySet()) { logger.info("Log function - {} - Loss - {}", functionName, functionLossMap.get(functionName).loss() ); } } private static List<RegressionLossFunction> getSecondaryLossFunctions(List<RegressionLossFunction> lossFunctions) { checkArgument(lossFunctions.size() > 0, "There must be at least one loss function supplied"); return lossFunctions.subList(1, lossFunctions.size()); } class RunningWeight implements Comparable<RunningWeight> { double runningLoss = 0; double runningWeightOfValidationSet = 0; private RegressionLossFunction lossFunction; public RunningWeight(RegressionLossFunction lossFunction) { this.lossFunction = lossFunction; } public void updateLosses(List<LabelPredictionWeight<Double, Double>> results) { double totalWeight=0; for (LabelPredictionWeight<Double, Double> lpw : results) { totalWeight+=lpw.getWeight(); } runningLoss += lossFunction.getLoss(results) * totalWeight; runningWeightOfValidationSet += totalWeight; // logger.info("loss for no missing attributes: {}",loss()); } public double loss() { return runningWeightOfValidationSet > 0 ? runningLoss / runningWeightOfValidationSet : 0; } @Override public int compareTo(RunningWeight o) { return Double.compare(loss(), o.loss()); } } }