package quickml.supervised.crossValidation.attributeImportance;
import com.google.common.collect.Lists;
import org.junit.Before;
import org.junit.Test;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierRMSELossFunction;
import java.util.List;
import java.util.Set;
import static com.google.common.collect.Lists.newArrayList;
import static com.google.common.collect.Sets.newHashSet;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Tests for {@link AttributeLossSummary}, driven by three stub trackers whose
 * overall loss and attribute sets are fixed when the fixture is built.
 */
public class AttributeLossSummaryTest {

    /** Loss function handed to the tracker superclass constructor by every stub. */
    private final ClassifierLossFunction lossFunction = new ClassifierRMSELossFunction();

    private AttributeLossSummary attributeLossSummary;
    private AttributeLossTracker lossTracker1;
    private AttributeLossTracker lossTracker2;
    private AttributeLossTracker lossTracker3;

    /**
     * Builds three stub trackers (losses 0.6, 0.2 and 0.4 with three, two and
     * five attributes respectively) and wraps them in the summary under test.
     */
    @Before
    public void setUp() throws Exception {
        lossTracker1 = stubTracker(0.6, "a", "b", "c");
        lossTracker2 = stubTracker(0.2, "a", "b");
        lossTracker3 = stubTracker(0.4, "a", "b", "c", "d", "e");
        attributeLossSummary = new AttributeLossSummary(
                Lists.newArrayList(lossTracker1, lossTracker2, lossTracker3));
    }

    /**
     * The optimal attributes are expected to be those of the stub with the
     * smallest loss (0.2), i.e. exactly "a" and "b".
     */
    @Test
    public void testGetOptimalAttributesReturnsTrackerWithLowestLoss() throws Exception {
        final int optimalCount = attributeLossSummary.getOptimalAttributes().size();
        assertEquals(2, optimalCount);

        final boolean hasA = attributeLossSummary.getOptimalAttributes().contains("a");
        assertTrue(hasA);

        final boolean hasB = attributeLossSummary.getOptimalAttributes().contains("b");
        assertTrue(hasB);
    }

    /**
     * Checks the size of the set returned by {@code getMaximalSet} for several
     * requested sizes.
     */
    @Test
    public void testgetMaximalAttributes() throws Exception {
        // Each row is {requested size, expected size of the returned set}.
        // NOTE(review): a request of 4 is expected to yield 5 attributes; the
        // selection rule lives in AttributeLossSummary and is not visible here.
        final int[][] expectations = {
                {5, 5},
                {4, 5},
                {2, 2},
        };
        for (int[] row : expectations) {
            final int requested = row[0];
            final int expectedSize = row[1];
            assertEquals(expectedSize, attributeLossSummary.getMaximalSet(requested).size());
        }
    }

    /** Creates a stub tracker reporting the given loss and attribute names. */
    private AttributeLossTracker stubTracker(double loss, String... attributeNames) {
        return new MockAttributeLossTracker(loss, newHashSet(attributeNames));
    }

    /**
     * Stub tracker that returns a preset loss and a preset attribute set
     * from the two overridden accessors.
     */
    class MockAttributeLossTracker extends AttributeLossTracker {

        private final double fixedLoss;
        private final Set<String> fixedAttributes;

        public MockAttributeLossTracker(double loss, Set<String> attributes) {
            super(attributes, Lists.newArrayList(lossFunction), lossFunction);
            this.fixedLoss = loss;
            this.fixedAttributes = attributes;
        }

        /** Returns the loss supplied at construction. */
        @Override
        public double getOverallLoss() {
            return fixedLoss;
        }

        /** Returns a new list copy of the attribute set supplied at construction. */
        @Override
        public List<String> getOrderedAttributes() {
            return newArrayList(fixedAttributes);
        }
    }
}