package hex.word2vec;

import org.junit.*;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;
import java.util.*;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.junit.Assume.assumeThat;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.*;

/**
 * JUnit 4 tests for the {@link Word2Vec} algorithm and {@link Word2VecModel}:
 * training (skip-gram + hierarchical softmax), importing pre-trained vectors,
 * exporting the model to a Frame, and transforming word columns to vectors
 * (per-word and sentence-averaged).
 *
 * <p>Most tests use {@code Scope.enter()}/{@code Scope.exit()} with
 * {@code Scope.track(...)} / {@code Scope.track_generic(...)} to guarantee
 * cleanup of Frames/Vecs/models in the distributed key-value store (DKV).
 */
public class Word2VecTest extends TestUtil {

  @BeforeClass()
  public static void setup() {
    // Wait until at least one H2O node is up before any test runs.
    stall_till_cloudsize(1);
  }

  /**
   * Trains a small SkipGram/HSM model on a synthetic corpus where "a" co-occurs
   * with "b" 100 times and with "c" 10 times, then checks that (1) the top-2
   * synonyms of "a" are exactly {"b", "c"} and (2) transform() produces
   * non-NA vectors for known words and NAs for unseen/missing words.
   */
  @Test
  public void testW2V_SG_HSM_small() {
    // Corpus: 100 interleaved "a b" pairs followed by 10 "a c" pairs.
    String[] words = new String[220];
    for (int i = 0; i < 200; i += 2) {
      words[i] = "a";
      words[i + 1] = "b";
    }
    for (int i = 200; i < 220; i += 2) {
      words[i] = "a";
      words[i + 1] = "c";
    }
    Scope.enter();
    try {
      Vec v = Scope.track(svec(words));
      Frame fr = Scope.track(new Frame(Key.<Frame>make(), new String[]{"Words"}, new Vec[]{v}));
      // The frame must be published to the DKV before training can reference it by key.
      DKV.put(fr);
      Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
      p._train = fr._key;
      p._min_word_freq = 5; // "a" (110x), "b" (100x), "c" (10x) all pass this threshold
      p._word_model = Word2Vec.WordModel.SkipGram;
      p._norm_model = Word2Vec.NormModel.HSM;
      p._vec_size = 10;
      p._window_size = 5;
      p._sent_sample_rate = 0.001f;
      p._init_learning_rate = 0.025f;
      p._epochs = 1;
      Word2VecModel w2vm = (Word2VecModel) Scope.track_generic(new Word2Vec(p).trainModel().get());
      Map<String, Float> hm = w2vm.findSynonyms("a", 2);
      logResults(hm);
      // The only other vocabulary words are "b" and "c", so they must be the 2 synonyms.
      assertEquals(new HashSet<>(Arrays.asList("b", "c")), hm.keySet());
      // Transform a mix of known ("a","b","c"), unseen ("Unseen") and missing (null) words.
      Vec testWordVec = Scope.track(svec("a", "b", "c", "Unseen", null));
      Frame wv = Scope.track(w2vm.transform(testWordVec, Word2VecModel.AggregateMethod.NONE));
      assertEquals(10, wv.numCols()); // one column per embedding dimension (_vec_size)
      for (int i = 0; i < 10; i++) {
        for (int j = 0; j < 3; j++)
          assertFalse(wv.vec(i).isNA(j)); // known words
        for (int j = 3; j < 5; j++)
          assertTrue(wv.vec(i).isNA(j)); // unseen & missing words
      }
    } finally {
      Scope.exit();
    }
  }

  /**
   * Builds a model purely from a pre-trained embedding frame (no training data)
   * and verifies that transform() reproduces each word's original 2-d vector,
   * including across an irregular chunk layout of the input frame.
   */
  @Test
  public void testW2V_pretrained() {
    String[] words = new String[1000];
    double[] v1 = new double[words.length];
    double[] v2 = new double[words.length];
    // Deterministic embeddings: v1 ramps 0..~1, v2 is its complement.
    for (int i = 0; i < words.length; i++) {
      words[i] = "word" + i;
      v1[i] = i / (float) words.length;
      v2[i] = 1 - v1[i];
    }
    Scope.enter();
    // Chunk sizes deliberately uneven (sum = 1000) to exercise chunk-boundary handling.
    Frame pretrained = new TestFrameBuilder()
        .withName("w2v-pretrained")
        .withColNames("Word", "V1", "V2")
        .withVecTypes(Vec.T_STR, Vec.T_NUM, Vec.T_NUM)
        .withDataForCol(0, words)
        .withDataForCol(1, v1)
        .withDataForCol(2, v2)
        .withChunkLayout(100, 100, 20, 80, 100, 100, 100, 100, 100, 100, 100)
        .build();
    Scope.track(pretrained);
    try {
      Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
      p._vec_size = 2;
      p._pre_trained = pretrained._key; // no _train key: model is built from the frame above
      Word2VecModel w2vm = (Word2VecModel) Scope.track_generic(new Word2Vec(p).trainModel().get());
      for (int i = 0; i < words.length; i++) {
        float[] wordVector = w2vm.transform(words[i]);
        assertArrayEquals("wordvec " + i,
            new float[]{(float) v1[i], (float) v2[i]}, wordVector, 0.0001f);
      }
    } finally {
      Scope.exit();
    }
  }

  /**
   * Round-trip: pre-trained frame -> model -> toFrame() should reproduce the
   * original word column and both embedding columns (within tolerance).
   */
  @Test
  public void testW2V_toFrame() {
    Random r = new Random(); // unseeded is fine: values are only compared to themselves
    String[] words = new String[1000];
    double[] v1 = new double[words.length];
    double[] v2 = new double[words.length];
    for (int i = 0; i < words.length; i++) {
      words[i] = "word" + i;
      v1[i] = r.nextDouble();
      v2[i] = r.nextDouble();
    }
    try {
      Scope.enter();
      Frame expected = new TestFrameBuilder()
          .withName("w2v")
          .withColNames("Word", "V1", "V2")
          .withVecTypes(Vec.T_STR, Vec.T_NUM, Vec.T_NUM)
          .withDataForCol(0, words)
          .withDataForCol(1, v1)
          .withDataForCol(2, v2)
          .withChunkLayout(100, 900)
          .build();
      Scope.track(expected);
      Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
      p._vec_size = 2;
      p._pre_trained = expected._key;
      Word2VecModel w2vm = (Word2VecModel) Scope.track_generic(new Word2Vec(p).trainModel().get());
      // convert to a Frame
      Frame result = Scope.track(w2vm.toFrame());
      assertArrayEquals(expected._names, result._names);
      assertStringVecEquals(expected.vec(0), result.vec(0));
      assertVecEquals(expected.vec(1), result.vec(1), 0.0001);
      assertVecEquals(expected.vec(2), result.vec(2), 0.0001);
    } finally {
      Scope.exit();
    }
  }

  /**
   * Large-scale smoke test on the text8 corpus: trains for 10 epochs and checks
   * that synonyms of "dog" include at least one plausible neighbor. Skipped
   * unless the system property {@code testW2V} is set (big data, long runtime).
   */
  @Test
  public void testW2V_SG_HSM() {
    assumeThat("word2vec test enabled", System.getProperty("testW2V"), is(notNullValue())); // ignored by default
    Frame fr = parse_test_file("bigdata/laptop/text8.gz", "NA", 0, new byte[]{Vec.T_STR});
    Word2VecModel w2vm = null;
    try {
      Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
      p._train = fr._key;
      p._min_word_freq = 5;
      p._word_model = Word2Vec.WordModel.SkipGram;
      p._norm_model = Word2Vec.NormModel.HSM;
      p._vec_size = 100;
      p._window_size = 4;
      p._sent_sample_rate = 0.001f;
      p._init_learning_rate = 0.025f;
      p._epochs = 10;
      w2vm = new Word2Vec(p).trainModel().get();
      Map<String, Float> hm = w2vm.findSynonyms("dog", 20);
      logResults(hm);
      // Loose semantic check: training is stochastic, so accept any of several neighbors.
      assertTrue(hm.containsKey("cat") || hm.containsKey("dogs") || hm.containsKey("hound"));
    } finally {
      // Manual cleanup here (no Scope in this test).
      fr.remove();
      if( w2vm != null) w2vm.delete();
    }
  }

  /**
   * Tests sentence-level AVERAGE aggregation of word vectors. A trivial model is
   * trained just to build the vocabulary {a, b}; its embedding matrix is then
   * overwritten with fixed unit vectors (a->(1,0), b->(0,1) in vocab order) so
   * the expected per-sentence averages can be computed by hand. Sentences are
   * null-terminated runs of words; "c" is out-of-vocabulary and is skipped, and
   * sentences with no known words must average to NaN.
   */
  @Test
  public void testTransformAggregate() {
    Scope.enter();
    try {
      Vec v = Scope.track(svec("a", "b"));
      Frame fr = Scope.track(new Frame(Key.<Frame>make(), new String[]{"Words"}, new Vec[]{v}));
      DKV.put(fr);
      // build an arbitrary w2v model & overwrite the learned vector with fixed values
      Word2VecModel.Word2VecParameters p = new Word2VecModel.Word2VecParameters();
      p._train = fr._key;
      p._min_word_freq = 0;
      p._epochs = 1;
      p._vec_size = 2;
      Word2VecModel w2vm = (Word2VecModel) Scope.track_generic(new Word2Vec(p).trainModel().get());
      w2vm._output._vecs = new float[] {1.0f, 0.0f, 0.0f, 1.0f};
      DKV.put(w2vm); // re-publish the mutated model so the transform task sees the new vectors
      // Chunk layout is part of the test: sentences deliberately span chunk boundaries.
      String[][] chunks = {
          new String[] {"a", "b", null, "a", "c", null, "c", null, "a", "a"},
          new String[] {"a", "b", null},
          new String[] {null, null},
          new String[] {"b", "b", "a"},
          new String[] {"b"} // no terminator at the end
      };
      long[] layout = new long[chunks.length];
      String[] sentences = new String[0];
      for (int i = 0; i < chunks.length; i++) {
        sentences = ArrayUtils.append(sentences, chunks[i]);
        layout[i] = chunks[i].length;
      }
      Frame f = new TestFrameBuilder()
          .withName("data")
          .withColNames("Sentences")
          .withVecTypes(Vec.T_STR)
          .withDataForCol(0, sentences)
          .withChunkLayout(layout)
          .build();
      Frame result = Scope.track(w2vm.transform(f.vec(0), Word2VecModel.AggregateMethod.AVERAGE));
      // 7 null-delimited sentences: [a,b], [a,c], [c], [a,a,a,b], [], [], [b,b,a,b].
      // "c" is OOV and ignored; sentences 3, 5 and 6 have no known words -> NaN.
      Vec expectedAs = Scope.track(dvec(0.5, 1.0, Double.NaN, 0.75, Double.NaN, Double.NaN, 0.25));
      Vec expectedBs = Scope.track(dvec(0.5, 0.0, Double.NaN, 0.25, Double.NaN, Double.NaN, 0.75));
      // Column index per word is looked up via the model's vocabulary mapping.
      assertVecEquals(expectedAs, result.vec(w2vm._output._vocab.get(new BufferedString("a"))), 0.0001);
      assertVecEquals(expectedBs, result.vec(w2vm._output._vocab.get(new BufferedString("b"))), 0.0001);
    } finally {
      Scope.exit();
    }
  }

  /**
   * Logs synonym-lookup results sorted by descending similarity score.
   *
   * @param hm map of synonym word to similarity score
   */
  private void logResults(Map<String, Float> hm) {
    List<Map.Entry<String, Float>> result = new ArrayList<>(hm.entrySet());
    Collections.sort(result, new Comparator<Map.Entry<String, Float>>() {
      @Override
      public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
        return o2.getValue().compareTo(o1.getValue()); // reverse sort
      }
    });
    int i = 0;
    // NOTE(review): raw Map.Entry type here; Map.Entry<String, Float> would be cleaner.
    for (Map.Entry entry : result)
      Log.info((i++) + ". " + entry.getKey() + ", " + entry.getValue());
  }
}