package hex.genmodel.algos.klime; import com.google.common.io.ByteStreams; import hex.genmodel.MojoModel; import hex.genmodel.MojoReaderBackend; import hex.genmodel.algos.glm.GlmMojoModel; import hex.genmodel.easy.EasyPredictModelWrapper; import hex.genmodel.easy.RowData; import hex.genmodel.easy.prediction.KLimeModelPrediction; import org.junit.Before; import org.junit.Test; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import static org.junit.Assert.*; public class KLimeMojoModelTest { private KLimeMojoModel _mojo; private double[][] _rows; private RowData[] _rowData; @Before public void setup() throws IOException { _mojo = (KLimeMojoModel) KLimeMojoReader.readFrom(new KLimeMojoModelTest.ClasspathReaderBackend()); _rows = new double[][] { new double[]{2.0, 1.0, 22.0, 1.0, 0.0}, new double[]{2.0, 1.0, 2.0, 3.0, 1.0}, new double[]{2.0, 0.0, 27.0, 0.0, 2.0} }; _rowData = new RowData[_rows.length]; for (int i = 0; i < _rows.length; i++) _rowData[i] = toRowData(_mojo, _rows[i]); } @Test public void testScore0() throws Exception { EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(_mojo); double[] preds = new double[7]; KLimeModelPrediction p; // prediction is made by a cluster-local GLM model p = wrapper.predictKLime(_rowData[0]); checkPrediction(p, _mojo); _mojo.score0(_rows[0], preds); assertEquals(preds[0], 0.127, 0.001); assertEquals(preds[1], 0, 0.0); checkPrediction(preds, _mojo); // data point belongs to cluster 1 but prediction is made by a global model p = wrapper.predictKLime(_rowData[1]); checkPrediction(p, _mojo); _mojo.score0(_rows[1], preds); assertEquals(preds[0], 0.141, 0.001); assertEquals(preds[1], 1, 0.0); checkPrediction(preds, _mojo); // data point belongs to cluster 2 but prediction is made by a global model p = wrapper.predictKLime(_rowData[2]); checkPrediction(p, _mojo); _mojo.score0(_rows[2], preds); assertEquals(preds[0], 0.596, 0.001); assertEquals(preds[1], 2, 0.0); checkPrediction(preds, _mojo); } private void checkPrediction(double[] preds, KLimeMojoModel mojo) { GlmMojoModel m = mojo.getRegressionModel((int) preds[1]); double p = m.getIntercept(); for (int i = 2; i < preds.length; i++) p += preds[i]; assertEquals(preds[0], p, 1e-6); } private void checkPrediction(KLimeModelPrediction pred, KLimeMojoModel mojo) { GlmMojoModel m = mojo.getRegressionModel(pred.cluster); double p = m.getIntercept(); for (int i = 0; i < pred.reasonCodes.length; i++) p += pred.reasonCodes[i]; assertEquals(pred.value, p, 1e-6); } private static RowData toRowData(MojoModel mojo, double[] row) { RowData rowData = new RowData(); for (String name : mojo._names) { int idx = mojo.getColIdx(name); if (idx >= row.length) continue; String[] domain = mojo.getDomainValues(idx); if (domain != null) rowData.put(name, domain[(int) row[idx]]); else rowData.put(name, row[idx]); } return rowData; } private static class ClasspathReaderBackend implements MojoReaderBackend { @Override public BufferedReader getTextFile(String filename) throws IOException { InputStream is = KLimeMojoModelTest.class.getResourceAsStream(filename); return new BufferedReader(new InputStreamReader(is)); } @Override public byte[] getBinaryFile(String filename) throws IOException { InputStream is = KLimeMojoModelTest.class.getResourceAsStream(filename); return ByteStreams.toByteArray(is); } @Override public boolean exists(String filename) { return true; } } }