package hex.genmodel.algos.word2vec;
import org.junit.Test;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import static org.junit.Assert.*;
/**
 * Unit tests for {@link Word2VecMojoReader}: checks that {@code readModelData()} turns the
 * stubbed vector blob and vocabulary text into a model whose per-word vectors match.
 */
public class Word2VecMojoReaderTest {

  /** Tolerance used when comparing float vectors. */
  private static final float DELTA = 0.0001f;

  @Test
  public void readModelData() throws Exception {
    TestedWord2VecMojoReader reader = new TestedWord2VecMojoReader();
    reader.readModelData();
    Word2VecMojoModel model = reader.getModel();
    // The stub blob holds the floats 0, 1, 2, ... in order; these assertions expect the
    // reader to assign consecutive VEC_SIZE-long slices to the vocabulary words in order.
    assertArrayEquals(new float[]{0.0f, 1.0f}, model.transform0("A", new float[2]), DELTA);
    assertArrayEquals(new float[]{2.0f, 3.0f}, model.transform0("B", new float[2]), DELTA);
    assertArrayEquals(new float[]{4.0f, 5.0f}, model.transform0("C", new float[2]), DELTA);
  }

  /**
   * A {@link Word2VecMojoReader} backed by in-memory fixtures instead of a MOJO archive:
   * a vocabulary of {@value #VOCAB_SIZE} words ("A", "B", "C") with
   * {@value #VEC_SIZE}-dimensional vectors.
   */
  private static final class TestedWord2VecMojoReader extends Word2VecMojoReader {

    /** Number of words returned by {@link #readtext}. */
    private static final int VOCAB_SIZE = 3;
    /** Number of floats per word vector. */
    private static final int VEC_SIZE = 2;

    private TestedWord2VecMojoReader() {
      // Placeholder model; the test expects readModelData() to fill in the real data.
      _model = new Word2VecMojoModel(new String[0], new String[0][]);
    }

    /**
     * Serves the fixture dimensions for "vocab_size" and "vec_size".
     *
     * @param key the model-info key being looked up
     * @param defVal value to return for keys this stub does not define
     * @return the stubbed Integer for known keys, otherwise {@code defVal}
     */
    @Override
    @SuppressWarnings("unchecked") // known keys are always requested with Integer defaults
    protected <T> T readkv(String key, T defVal) {
      if ("vocab_size".equals(key)) {
        return (T) Integer.valueOf(VOCAB_SIZE);
      }
      if ("vec_size".equals(key)) {
        return (T) Integer.valueOf(VEC_SIZE);
      }
      // Fall back to the caller's default rather than null, so a lookup of an
      // unstubbed key with a primitive default does not fail on unboxing.
      return defVal;
    }

    /**
     * Returns VOCAB_SIZE * VEC_SIZE floats with the values 0, 1, 2, ..., encoded in
     * ByteBuffer's default (big-endian) byte order. The blob name is ignored.
     */
    @Override
    protected byte[] readblob(String name) throws IOException {
      int floatCount = VOCAB_SIZE * VEC_SIZE;
      ByteBuffer bb = ByteBuffer.allocate(floatCount * (Float.SIZE / Byte.SIZE));
      for (int i = 0; i < floatCount; i++) {
        bb.putFloat(i);
      }
      return bb.array();
    }

    /** Reports every requested entry as present. */
    @Override
    protected boolean exists(String name) {
      return true;
    }

    /**
     * Returns the fixed vocabulary, asserting that the reader asks for newline unescaping.
     */
    @Override
    protected Iterable<String> readtext(String name, boolean unescapeNewlines) throws IOException {
      assertTrue(unescapeNewlines);
      return Arrays.asList("A", "B", "C");
    }

    /** Exposes the protected model field to the test. */
    private Word2VecMojoModel getModel() {
      return _model;
    }
  }
}