package quickml.supervised.predictiveModelOptimizer;

import com.google.common.collect.Maps;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import quickml.supervised.crossValidation.ClassifierLossChecker;
import quickml.supervised.crossValidation.SimpleCrossValidator;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;

import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.anyMap;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;

/**
 * Tests {@link PredictiveModelOptimizer} against a mocked {@link SimpleCrossValidator}.
 *
 * <p>The cross validator's loss is stubbed per candidate configuration, so the configuration the
 * optimizer should settle on is known in advance.
 */
public class PredictiveModelOptimizerTest {

    @Mock
    private SimpleCrossValidator mockSimpleCrossValidator;

    // NOTE(review): not referenced by any test in this class; candidate for removal.
    @Mock
    private ClassifierLossChecker mockLossChecker;

    private PredictiveModelOptimizer modelOptimizer;

    @Before
    public void setUp() throws Exception {
        initMocks(this);
        // A TreeMap iterates its String keys in sorted order, which gives the optimizer a
        // deterministic field order.
        Map<String, FixedOrderRecommender> fields = new TreeMap<>();
        fields.put("treeDepth", new FixedOrderRecommender(1, 2, 3, 4, 5));
        fields.put("penalize_splits", new FixedOrderRecommender(true, false));
        fields.put("scorerFactory", new FixedOrderRecommender("A", "B", "C"));
        // NOTE(review): the meaning of the constant 10 is not visible here (presumably an
        // iteration/round limit) -- confirm against PredictiveModelOptimizer's constructor.
        modelOptimizer = new PredictiveModelOptimizer(fields, mockSimpleCrossValidator, 10);
    }

    /**
     * Stubs a default loss of 0.5 for every configuration and lower losses for three specific
     * configurations. Then checks that the optimizer returns the lowest-loss configuration.
     */
    @Test
    public void testFindSimpleBestConfig() throws Exception {
        // Fields are checked in the following order: penalize_splits, scorerFactory, treeDepth.
        // Each successively better config differs from the previous one in a single field.
        final Map<String, Object> thirdBestConfig = createMap(1, false, "A");
        final Map<String, Object> secondBestConfig = createMap(1, false, "C");
        final Map<String, Object> bestConfig = createMap(5, false, "C");

        // Catch-all stub first. The more specific eq(...) stubs registered afterwards override
        // it for the configurations they match.
        when(mockSimpleCrossValidator.getLossForModel(anyMap())).thenReturn(0.5);
        when(mockSimpleCrossValidator.getLossForModel(eq(thirdBestConfig))).thenReturn(0.4);
        when(mockSimpleCrossValidator.getLossForModel(eq(secondBestConfig))).thenReturn(0.2);
        when(mockSimpleCrossValidator.getLossForModel(eq(bestConfig))).thenReturn(0.1);

        assertEquals(bestConfig, modelOptimizer.determineOptimalConfig());
    }

    /**
     * Builds a configuration map keyed by the same field names registered in {@link #setUp()}.
     *
     * @param treeDepth value for the "treeDepth" field
     * @param penalizeSplits value for the "penalize_splits" field
     * @param scorer value for the "scorerFactory" field
     * @return a new mutable map holding the three field values
     */
    private Map<String, Object> createMap(int treeDepth, boolean penalizeSplits, String scorer) {
        Map<String, Object> map = new HashMap<>();
        map.put("treeDepth", treeDepth);
        map.put("penalize_splits", penalizeSplits);
        map.put("scorerFactory", scorer);
        return map;
    }
}