package hex.genmodel.algos.gbm;
import com.google.common.io.ByteStreams;
import hex.genmodel.ModelMojoReader;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.algos.word2vec.Word2VecMojoModelTest;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import org.junit.Test;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import static org.junit.Assert.*;
/**
 * Reads a GBM MOJO from classpath resources stored under {@code calibrated/} (resolved relative
 * to this class) and checks the binomial prediction it produces for one fixed input row,
 * including the calibrated class probabilities.
 */
public class GbmMojoModelTest {

  /**
   * Scores a single row through {@link EasyPredictModelWrapper} and compares the predicted label,
   * the raw class probabilities and the calibrated class probabilities against fixed expected
   * values (probabilities are compared with an absolute tolerance of 1e-5).
   */
  @Test
  public void testPredict() throws Exception {
    GbmMojoModel mojo = (GbmMojoModel) ModelMojoReader.readFrom(new ClasspathReaderBackend());
    assertNotNull(mojo);

    // Fill the row with plain put() calls instead of double-brace initialization, which would
    // create an anonymous RowData subclass only to hold these values.
    RowData row = new RowData();
    row.put("SegSumT", 18.7);
    row.put("SegTSeas", 1.51);
    row.put("SegLowFlow", 1.003);
    row.put("DSDist", 132.53);
    row.put("DSMaxSlope", 1.15);
    row.put("USAvgT", 0.2);
    row.put("USRainDays", 1.153);
    row.put("USSlope", 8.3);
    row.put("USNative", 0.34);
    row.put("DSDam", 0.0);
    row.put("Method", "electric");

    EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(mojo);
    BinomialModelPrediction pred = (BinomialModelPrediction) wrapper.predict(row);

    assertEquals(1, pred.labelIndex);
    assertEquals("1", pred.label);
    assertArrayEquals(new double[]{0.5416688, 0.4583312}, pred.classProbabilities, 1e-5);
    assertArrayEquals(new double[]{0.3920402, 0.6079598}, pred.calibratedClassProbabilities, 1e-5);
  }

  /**
   * {@link MojoReaderBackend} that serves MOJO files from classpath resources located in the
   * {@code calibrated/} directory next to this test class.
   */
  private static class ClasspathReaderBackend implements MojoReaderBackend {

    /** Resource path prefix, resolved relative to {@link GbmMojoModelTest}. */
    private static final String RESOURCE_DIR = "calibrated/";

    /**
     * Opens the named resource as text.
     *
     * @param filename resource name relative to {@code calibrated/}
     * @return a buffered reader over the resource; it is not closed here, so closing is left to
     *     the caller
     * @throws IOException if the resource is not present on the classpath
     */
    @Override
    public BufferedReader getTextFile(String filename) throws IOException {
      // Decodes with the platform default charset, exactly as before.
      return new BufferedReader(new InputStreamReader(openResource(filename)));
    }

    /**
     * Reads the named resource fully into memory.
     *
     * @param filename resource name relative to {@code calibrated/}
     * @return the complete contents of the resource
     * @throws IOException if the resource is not present on the classpath or cannot be read
     */
    @Override
    public byte[] getBinaryFile(String filename) throws IOException {
      InputStream is = openResource(filename);
      try {
        return ByteStreams.toByteArray(is);
      } finally {
        // ByteStreams.toByteArray leaves the stream open, so release it here.
        is.close();
      }
    }

    /**
     * Always answers {@code true} without looking at the classpath; a resource that is actually
     * missing is reported as an {@link IOException} when it is read.
     */
    @Override
    public boolean exists(String name) {
      return true;
    }

    /**
     * Looks up a resource under {@code calibrated/}, turning the {@code null} that
     * {@code getResourceAsStream} returns for a missing resource into a descriptive exception
     * instead of a later {@link NullPointerException}.
     */
    private static InputStream openResource(String filename) throws IOException {
      String path = RESOURCE_DIR + filename;
      InputStream is = GbmMojoModelTest.class.getResourceAsStream(path);
      if (is == null) {
        throw new IOException("MOJO test resource not found on classpath: " + path);
      }
      return is;
    }
  }
}