package quickml.supervised.crossValidation.attributeImportance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
 * Summarizes a collection of {@link RegAttributeLossTracker}s, ordered by ascending overall loss
 * (the tracker with the lowest loss comes first).
 *
 * <p>The trackers are copied on construction, so the caller's list is neither reordered nor
 * able to affect this summary through later modifications.
 */
public class RegAttributeLossSummary {
    /** Private copy of the trackers, sorted by ascending {@code getOverallLoss()}. */
    private final List<RegAttributeLossTracker> lossTrackers;

    /**
     * Creates a summary over the given trackers.
     *
     * @param lossTrackers the trackers to summarize; copied, never modified. Must not be null.
     */
    public RegAttributeLossSummary(List<RegAttributeLossTracker> lossTrackers) {
        // Copy before sorting: sorting the argument in place would reorder the caller's list
        // (and would fail outright for unmodifiable lists).
        this.lossTrackers = new ArrayList<RegAttributeLossTracker>(lossTrackers);
        sortTrackers(this.lossTrackers);
    }

    /**
     * Returns the ordered attributes of the tracker with the lowest overall loss.
     *
     * @return the lowest-loss tracker's {@code getOrderedAttributes()} list
     * @throws IndexOutOfBoundsException if this summary contains no trackers
     */
    public List<String> getOptimalAttributes() {
        return lossTrackers.get(0).getOrderedAttributes();
    }

    /**
     * Returns the ordered losses of the tracker whose attribute set size is closest to {@code n}.
     * If several trackers are equally close to {@code n}, the one with the lowest overall loss wins.
     *
     * @param n the desired number of attributes
     * @return the chosen tracker's {@code getOrderedLosses()} list
     * @throws IndexOutOfBoundsException if this summary contains no trackers
     */
    public List<AttributeWithLoss> getMaximalSet(int n) {
        RegAttributeLossTracker optimalSet = lossTrackers.get(0);
        for (RegAttributeLossTracker lossTracker : lossTrackers) {
            if (isCloserToOptimalSize(n, optimalSet, lossTracker)
                    || (equallyCloseToOptimalSet(n, optimalSet, lossTracker)
                            && lossIsBetter(optimalSet, lossTracker))) {
                optimalSet = lossTracker;
            }
        }
        return optimalSet.getOrderedLosses();
    }

    /** Returns true if {@code lossTracker} has a strictly lower overall loss than {@code optimalSet}. */
    private boolean lossIsBetter(RegAttributeLossTracker optimalSet, RegAttributeLossTracker lossTracker) {
        return lossTracker.getOverallLoss() < optimalSet.getOverallLoss();
    }

    /** Returns true if both trackers' attribute-set sizes are the same distance from {@code n}. */
    private boolean equallyCloseToOptimalSet(int n, RegAttributeLossTracker optimalSet, RegAttributeLossTracker lossTracker) {
        return Math.abs(lossTracker.getOrderedAttributes().size() - n)
                == Math.abs(optimalSet.getOrderedAttributes().size() - n);
    }

    /** Returns true if {@code lossTracker}'s attribute-set size is strictly closer to {@code n}. */
    private boolean isCloserToOptimalSize(int n, RegAttributeLossTracker optimalSet, RegAttributeLossTracker lossTracker) {
        return Math.abs(lossTracker.getOrderedAttributes().size() - n)
                < Math.abs(optimalSet.getOrderedAttributes().size() - n);
    }

    /** Sorts {@code trackers} in place by ascending overall loss. */
    private void sortTrackers(List<RegAttributeLossTracker> trackers) {
        Collections.sort(trackers, new Comparator<RegAttributeLossTracker>() {
            @Override
            public int compare(RegAttributeLossTracker o1, RegAttributeLossTracker o2) {
                return Double.compare(o1.getOverallLoss(), o2.getOverallLoss());
            }
        });
    }
}