package hex.genmodel.algos.kmeans;
import com.google.common.io.ByteStreams;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import org.junit.Before;
import org.junit.Test;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import static org.junit.Assert.*;
/**
 * Tests scoring of a KMeans MOJO that is loaded from test resources on the classpath.
 *
 * <p>Each fixture row is scored through both the easy-predict wrapper and the low-level
 * {@code score0} path, and both are asserted to return the row's index as its cluster.
 */
public class KMeansMojoModelTest {

  private MojoModel _mojo;
  /** Numeric-encoded input rows; the test asserts row {@code i} is assigned to cluster {@code i}. */
  private double[][] _rows;
  /** The same rows as {@link #_rows}, converted to {@link RowData} for the easy-predict API. */
  private RowData[] _rowData;

  /** Loads the MOJO from the classpath and prepares the fixture rows in both input formats. */
  @Before
  public void setup() throws IOException {
    _mojo = KMeansMojoReader.readFrom(new ClasspathReaderBackend());
    _rows = new double[][] {
        {2.0, 1.0, 22.0, 1.0, 0.0},
        {2.0, 1.0, 2.0, 3.0, 1.0},
        {2.0, 0.0, 27.0, 0.0, 2.0}
    };
    _rowData = new RowData[_rows.length];
    for (int i = 0; i < _rows.length; i++) {
      _rowData[i] = toRowData(_mojo, _rows[i]);
    }
  }

  /** Scores every fixture row through both prediction APIs and checks the assigned cluster. */
  @Test
  public void testPredict() throws Exception {
    EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(_mojo);
    // Iterate over the fixture itself rather than a hard-coded count so the two cannot drift.
    for (int i = 0; i < _rows.length; i++) {
      // Easy-predict API: takes named columns with categorical levels given as strings.
      ClusteringModelPrediction p = (ClusteringModelPrediction) wrapper.predict(_rowData[i]);
      assertEquals("easy-predict cluster for row " + i, i, p.cluster);
      // Low-level API: preds[0] is expected to hold the cluster index after score0.
      double[] preds = new double[1];
      _mojo.score0(_rows[i], preds);
      assertEquals("score0 cluster for row " + i, i, preds[0], 0.0);
    }
  }

  /**
   * Converts a numeric-encoded row into a {@link RowData} keyed by column name.
   *
   * <p>Columns that have a domain are treated as categorical. Their value in {@code row} is used
   * as an index into the domain, and the resulting level string is stored. All other columns are
   * stored as their raw double value.
   *
   * @param mojo the model whose column names and domains drive the conversion
   * @param row values indexed by the model's column indices
   * @return the row keyed by column name
   */
  private static RowData toRowData(MojoModel mojo, double[] row) {
    RowData rowData = new RowData();
    for (String name : mojo._names) {
      int idx = mojo.getColIdx(name);
      String[] domain = mojo.getDomainValues(idx);
      if (domain != null) {
        rowData.put(name, domain[(int) row[idx]]);
      } else {
        rowData.put(name, row[idx]);
      }
    }
    return rowData;
  }

  /**
   * {@link MojoReaderBackend} that serves MOJO files as classpath resources, resolved relative to
   * this test class's package.
   */
  private static class ClasspathReaderBackend implements MojoReaderBackend {
    @Override
    public BufferedReader getTextFile(String filename) throws IOException {
      // Decode as UTF-8 explicitly; the charset-less InputStreamReader uses the platform default.
      return new BufferedReader(new InputStreamReader(open(filename), "UTF-8"));
    }

    @Override
    public byte[] getBinaryFile(String filename) throws IOException {
      InputStream is = open(filename);
      try {
        return ByteStreams.toByteArray(is);
      } finally {
        // ByteStreams.toByteArray does not close its input, so close it here.
        is.close();
      }
    }

    @Override
    public boolean exists(String filename) {
      // Report actual presence so callers probing for optional files are not misled.
      return KMeansMojoModelTest.class.getResource(filename) != null;
    }

    /**
     * Opens a classpath resource relative to this class.
     *
     * @throws IOException if the resource does not exist, instead of returning null and causing an
     *     NPE later
     */
    private static InputStream open(String filename) throws IOException {
      InputStream is = KMeansMojoModelTest.class.getResourceAsStream(filename);
      if (is == null) {
        throw new IOException("MOJO resource not found on classpath: " + filename);
      }
      return is;
    }
  }
}