package com.cyc.tool.distributedrepresentations; /* * #%L * DistributedRepresentations * %% * Copyright (C) 2015 Cycorp, Inc * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import java.io.IOException; import java.util.Arrays; import java.util.List; import org.junit.AfterClass; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import org.junit.BeforeClass; import org.junit.Test; /** * Tests for Word2VecSpace. */ public class Word2VecSpaceIT { static List<String> cr = Arrays.asList("Chinese", "river"); static Word2VecSpace mySpace; public Word2VecSpaceIT() { } @BeforeClass public static void setUpClass() throws IOException { mySpace = GoogleNewsW2VSpace.get(); } @AfterClass public static void tearDownClass() { mySpace = null; } // @Test public void distanceTest() { assertEquals(1.0, mySpace.cosineSimilarity("skimpy bathing suits", "skimpy_bathing_suits"), 0.00000001); assertEquals(0.24279, mySpace.cosineSimilarity("skimpy bathing suits", "Giant Octopus"), 0.0001); assertEquals(0.54801, mySpace.cosineSimilarity("skimpy bathing suits", "bathing suits"), 0.0001); assertEquals(0.645069, mySpace.cosineSimilarity("apple", "pear"), 0.0001); assertEquals(0.20749, mySpace.cosineSimilarity("apple", "cat"), 0.0001); assertTrue(mySpace.cosineSimilarity("apple", "pear") > mySpace.cosineSimilarity("apple", "cat")); } @Test public void getVectorTest1() { assertEquals(-0.05338118f, (mySpace.getVector("skimpy bathing suits")[5]), 0.000001); assertEquals(0.047296f, (mySpace.getVector("skimpy bathing suits")[105]), 0.000001); } @Test public void getVectorTest2a() { assertEquals(-0.049851f, (mySpace.getVector("Chinese")[0]), 0.000001); assertEquals(-0.090444f, (mySpace.getVector("Chinese")[5]), 0.000001); } @Test public void getVectorTest2b() { assertEquals(0.002663f, (mySpace.getVector("river")[0]), 0.000001); assertEquals(-0.029231f, (mySpace.getVector("river")[5]), 0.000001); } @Test public void googleDistanceTest1() { try { assertEquals(0.667376, mySpace.googleSimilarity(cr, "Yangtze_River"), 0.0001); } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { fail("took unexpected exception:" + ex); } } @Test public void googleDistanceTest2() { try { assertEquals(0.594108, mySpace.googleSimilarity(cr, "Hongze_Lake"), 0.0001); } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { fail("took unexpected exception:" + ex); } } @Test public void googleDistanceTest3() { try { assertEquals(0.604726, mySpace.googleSimilarity(cr, "Huangpu_River"), 0.0001); } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { fail("took unexpected exception:" + ex); } } @Test public void googleNormVectorTest0() { try { float[] norm = mySpace.getGoogleNormedVector(cr); assertEquals(-0.032075, norm[0], 0.000001); } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { fail("took unexpected exception:" + ex); } } @Test public void googleNormVectorTest100() { float[] norm; try { norm = mySpace.getGoogleNormedVector(cr); assertEquals(-0.095236, norm[100], 0.000001); } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { fail("took unexpected exception:" + ex); } } @Test public void googleNormVectorTest5() { try { float[] norm = mySpace.getGoogleNormedVector(cr); assertEquals(-0.081347, norm[5], 0.000001); } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { fail("took unexpected exception:" + ex); } } @Test public void googleNormVectorTest50() { try { float[] norm = mySpace.getGoogleNormedVector(cr); assertEquals(0.080537, norm[50], 0.000001); } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { fail("took unexpected exception:" + ex); } } /** * Test if known terms have been loaded from the Word2Vec file or DB */ @Test public void knownTermTest() { // System.out.println("DB Size:" + vectors.size()); assertTrue(mySpace.knownTerm("Yathra")); assertTrue(mySpace.knownTerm("skimpy bathing suits")); assertTrue(mySpace.knownTerm("Giant_Octopus")); assertTrue(mySpace.knownTerm("Yangtze_River")); assertTrue(mySpace.knownTerm("Chinese")); // assertTrue(mySpace.knownTerm("Chinese River")); } // @Test // public void findNearbyTerms1() { // try { // long t1 = System.currentTimeMillis(); // List<ConceptMatch> matches = mySpace.findNearestNForWithInputTermFiltering(cr, 40); // IntStream.range(0, matches.size()) // .forEach(i -> { // System.out.println(i + " " + matches.get(i).toString()); // }); // System.out.println("Took " + (System.currentTimeMillis() - t1) + "ms"); // assertEquals(matches.get(0).getTerm(), "Yangtze_River"); // assertEquals(0.604726, matches.get(5).getSimilarity(), 0.000001); // // assertEquals(matches.get(23).getTerm(), "rivers"); // } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { // fail("took unexpected exception:" + ex); // } // } // // @Test // // public void findNearbyTerms2() { // try { // long t1 = System.currentTimeMillis(); // List<ConceptMatch> matches = mySpace.findNearestNForWithInputTermFiltering(Arrays.asList("gangplank"), 40); // IntStream.range(0, matches.size()) // .forEach(i -> { // System.out.println(i + " " + matches.get(i).toString()); // }); // System.out.println("Took " + (System.currentTimeMillis() - t1) + "ms"); // } catch (Word2VecSpace.NoWordToVecVectorForTerm ex) { // fail("took unexpected exception:" + ex); // } // } @Test public void testNGramsFor() { List<String> res = Word2VecSpace.nGramsFor(Arrays.asList("this", "is", "a", "test")); // System.out.println("test: "+res+" len:"+res.size()); assertEquals(10, res.size()); } @Test public void testNGramsForCR() { List<String> res = Word2VecSpace.nGramsFor(cr); System.out.println("test: " + res + " len:" + res.size()); assertEquals(3, res.size()); } }