package hex.genmodel.algos.word2vec;
import com.google.common.io.ByteStreams;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import org.junit.Test;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import static org.junit.Assert.*;
/**
 * Verifies word-embedding lookups on a Word2Vec MOJO that is loaded from
 * classpath resources located relative to this test class.
 */
public class Word2VecMojoModelTest {

  /**
   * Loads the MOJO through {@link ClasspathReaderBackend} and checks the vector
   * size, the embeddings of two known words, and the {@code null} result for a
   * word that is not in the dictionary.
   */
  @Test
  public void testTransform0() throws Exception {
    MojoModel mojo = Word2VecMojoReader.readFrom(new ClasspathReaderBackend());
    assertTrue(mojo instanceof WordEmbeddingModel);

    WordEmbeddingModel m = (WordEmbeddingModel) mojo;
    assertEquals(3, m.getVecSize());
    assertArrayEquals(new float[]{0.0f, 1.0f, 0.2f}, m.transform0("a", new float[3]), 0.0001f);
    assertArrayEquals(new float[]{1.0f, 0.0f, 0.8f}, m.transform0("b", new float[3]), 0.0001f);
    assertNull(m.transform0("c", new float[3])); // out-of-dictionary word
  }

  /**
   * {@link MojoReaderBackend} that serves MOJO files from the classpath,
   * resolving names the same way {@link Class#getResourceAsStream(String)} does
   * for this test class.
   */
  private static class ClasspathReaderBackend implements MojoReaderBackend {

    /**
     * Opens the named resource as UTF-8 text. The caller owns the returned
     * reader and is responsible for closing it.
     *
     * @throws IOException if the resource is not on the classpath
     */
    @Override
    public BufferedReader getTextFile(String filename) throws IOException {
      // Pin the charset so decoding does not depend on the platform default.
      return new BufferedReader(new InputStreamReader(open(filename), "UTF-8"));
    }

    /**
     * Reads the named resource fully into memory.
     *
     * @throws IOException if the resource is not on the classpath or cannot be read
     */
    @Override
    public byte[] getBinaryFile(String filename) throws IOException {
      InputStream is = open(filename);
      try {
        return ByteStreams.toByteArray(is);
      } finally {
        // ByteStreams.toByteArray does not close the stream it is given.
        is.close();
      }
    }

    /** Reports whether the named resource is actually present on the classpath. */
    @Override
    public boolean exists(String filename) {
      return Word2VecMojoModelTest.class.getResource(filename) != null;
    }

    /**
     * Opens a classpath resource, turning the {@code null} that
     * {@code getResourceAsStream} returns for a missing resource into an
     * {@link IOException} that names the file.
     */
    private static InputStream open(String filename) throws IOException {
      InputStream is = Word2VecMojoModelTest.class.getResourceAsStream(filename);
      if (is == null) {
        throw new IOException("Classpath resource not found: " + filename);
      }
      return is;
    }
  }
}