package hex.genmodel.easy; import hex.ModelCategory; import hex.genmodel.GenModel; import hex.genmodel.MojoModel; import hex.genmodel.algos.word2vec.WordEmbeddingModel; import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException; import hex.genmodel.easy.prediction.BinomialModelPrediction; import hex.genmodel.easy.prediction.SortedClassProbability; import hex.genmodel.easy.prediction.Word2VecPrediction; import org.junit.Assert; import org.junit.Test; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; public class EasyPredictModelWrapperTest { private static class MyModel extends GenModel { MyModel(String[] names, String[][] domains) { super(names, domains); } @Override public int nclasses() { return 2; } @Override public boolean isSupervised() { return true; } @Override public double[] score0(double[] data, double[] preds) { Assert.assertEquals(preds.length, 3); preds[0] = 0; preds[1] = 1.0; preds[2] = 0.0; return preds; } @Override public ModelCategory getModelCategory() { return ModelCategory.Binomial; } @Override public String getUUID() { return null; } } private static MyModel makeModel() { String[] names = { "C1", "C2", "RESPONSE"}; String[][] domains = { {"c1level1", "c1level2"}, {"c2level1", "c2level2", "c2level3"}, {"NO", "YES"} }; return new MyModel(names, domains); } @Test public void testUnknownCategoricalLevels() throws Exception { MyModel rawModel = makeModel(); EasyPredictModelWrapper m = new EasyPredictModelWrapper(rawModel); { RowData row = new RowData(); row.put("C1", "c1level1"); try { m.predictBinomial(row); } catch (PredictUnknownCategoricalLevelException e) { Assert.fail("Caught exception but should not have"); } ConcurrentHashMap<String, AtomicLong> unknown = m.getUnknownCategoricalLevelsSeenPerColumn(); long total = 0; for (AtomicLong l : unknown.values()) { total += l.get(); } Assert.assertEquals(total, 0); } { RowData row = new RowData(); row.put("C1", "c1level1"); row.put("C2", "unknownLevel"); boolean caught = false; try { m.predictBinomial(row); } catch (PredictUnknownCategoricalLevelException e) { caught = true; } Assert.assertEquals(caught, true); ConcurrentHashMap<String, AtomicLong> unknown = m.getUnknownCategoricalLevelsSeenPerColumn(); long total = 0; for (AtomicLong l : unknown.values()) { total += l.get(); } Assert.assertEquals(total, 0); } m = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config() .setModel(rawModel) .setConvertUnknownCategoricalLevelsToNa(true) .setConvertInvalidNumbersToNa(true)); { RowData row0 = new RowData(); m.predict(row0); Assert.assertEquals(m.getTotalUnknownCategoricalLevelsSeen(), 0); RowData row1 = new RowData(); row1.put("C1", "c1level1"); row1.put("C2", "unknownLevel"); m.predictBinomial(row1); Assert.assertEquals(m.getTotalUnknownCategoricalLevelsSeen(), 1); RowData row2 = new RowData(); row2.put("C1", "c1level1"); row2.put("C2", "c2level3"); m.predictBinomial(row2); Assert.assertEquals(m.getTotalUnknownCategoricalLevelsSeen(), 1); RowData row3 = new RowData(); row3.put("C1", "c1level1"); row3.put("unknownColumn", "unknownLevel"); m.predictBinomial(row3); Assert.assertEquals(m.getTotalUnknownCategoricalLevelsSeen(), 1); m.predictBinomial(row1); m.predictBinomial(row1); Assert.assertEquals(m.getTotalUnknownCategoricalLevelsSeen(), 3); RowData row4 = new RowData(); row4.put("C1", "unknownLevel"); m.predictBinomial(row4); Assert.assertEquals(m.getTotalUnknownCategoricalLevelsSeen(), 4); Assert.assertEquals(m.getUnknownCategoricalLevelsSeenPerColumn().get("C1").get(), 1); Assert.assertEquals(m.getUnknownCategoricalLevelsSeenPerColumn().get("C2").get(), 3); } } @Test public void testSortedClassProbability() throws Exception { MyModel rawModel = makeModel(); EasyPredictModelWrapper m = new EasyPredictModelWrapper(rawModel); { RowData row = new RowData(); row.put("C1", "c1level1"); BinomialModelPrediction p = m.predictBinomial(row); SortedClassProbability[] arr = m.sortByDescendingClassProbability(p); Assert.assertEquals(arr[0].name, "NO"); Assert.assertEquals(arr[0].probability, 1.0, 0.001); Assert.assertEquals(arr[1].name, "YES"); Assert.assertEquals(arr[1].probability, 0.0, 0.001); } } @Test public void testWordEmbeddingModel() throws Exception { MyWordEmbeddingModel rawModel = new MyWordEmbeddingModel(); EasyPredictModelWrapper m = new EasyPredictModelWrapper(rawModel); RowData row = new RowData(); row.put("C0", -1); // should be ignored row.put("C1", "0.9,0.1"); row.put("C2", "0.1,0.9"); row.put("C3", "NA"); Word2VecPrediction p = m.predictWord2Vec(row); Assert.assertFalse(p.wordEmbeddings.containsKey("C0")); Assert.assertArrayEquals(new float[]{0.9f, 0.1f}, p.wordEmbeddings.get("C1"), 0.0001f); Assert.assertArrayEquals(new float[]{0.1f, 0.9f}, p.wordEmbeddings.get("C2"), 0.0001f); Assert.assertTrue(p.wordEmbeddings.containsKey("C3")); Assert.assertNull(p.wordEmbeddings.get("C3")); } private static class MyWordEmbeddingModel extends MojoModel implements WordEmbeddingModel { public MyWordEmbeddingModel() { super(new String[0], new String[0][]); } @Override public int getVecSize() { return 2; } @Override public float[] transform0(String word, float[] output) { if (word.equals("NA")) return null; String[] words = word.split(","); for (int i = 0; i < words.length; i++) output[i] = Float.valueOf(words[i]); return output; } @Override public double[] score0(double[] row, double[] preds) { throw new IllegalStateException("Should never be called"); } @Override public ModelCategory getModelCategory() { return ModelCategory.WordEmbedding; } } }