package quickml.supervised.crossValidation.attributeImportance; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight; import quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions.RegressionLossFunction; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; /** * Take a list of loss functions and keep a running total of the loss for each loss function per attribute */ public class RegAttributeLossTracker { private static final Logger logger = LoggerFactory.getLogger(RegAttributeLossTracker.class); private final RegLossFunctionTracker allAttributeLossTracker; private Map<String, RegLossFunctionTracker> attributeLossMap = Maps.newHashMap(); public RegAttributeLossTracker(Set<String> attributes, List<RegressionLossFunction> lossFunctions, RegressionLossFunction primaryLossFunction) { for (String attribute : attributes) { attributeLossMap.put(attribute, new RegLossFunctionTracker(lossFunctions, primaryLossFunction)); } allAttributeLossTracker = new RegLossFunctionTracker(lossFunctions, primaryLossFunction); } public void updateAttribute(String attribute, List<LabelPredictionWeight<Double, Double>> results) { attributeLossMap.get(attribute).updateLosses(results); } public void noMissingAttributeLoss(List<LabelPredictionWeight<Double, Double>> predictionMapResults) { allAttributeLossTracker.updateLosses(predictionMapResults); } public List<String> getOrderedAttributes() { List<String> attributes = Lists.newArrayList(); for (AttributeWithLoss attributeWithLoss : getOrderedLosses()) { attributes.add(attributeWithLoss.getAttribute()); } return attributes; } public List<AttributeWithLoss> getOrderedLosses() { List<AttributeWithLoss> list = Lists.newArrayList(); for (String attribute : attributeLossMap.keySet()) { list.add(new AttributeWithLoss(attribute, attributeLossMap.get(attribute).getPrimaryLoss())); } Collections.sort(list); return list; } public double getOverallLoss() { return allAttributeLossTracker.getPrimaryLoss(); } public void logResults() { logger.info("----- Attribute Loss Tracker - Number of attributes {} ----", attributeLossMap.keySet().size()); for (String attribute : getOrderedAttributes()) { logger.info("Attribute {}", attribute); attributeLossMap.get(attribute).logLosses(); } } }