package org.wikibrain.matrix.knn; import gnu.trove.set.TIntSet; import gnu.trove.set.hash.TIntHashSet; import org.junit.Test; import org.wikibrain.matrix.DenseMatrix; import org.wikibrain.matrix.DenseMatrixRow; import java.io.IOException; import static junit.framework.Assert.assertEquals; /** * @author Shilad Sen */ public class TestLSHForestNNFinder { DenseMatrix matrix; KNNFinder finder; private void makeSmall() throws IOException { matrix = TestUtils.createMatrix(1000, 20); LSHForestKNNFinder rp = new LSHForestKNNFinder(matrix); rp.build(); finder = rp; } private void makeBig() throws IOException { matrix = TestUtils.createMatrix(40000, 100); LSHForestKNNFinder rp = new LSHForestKNNFinder(matrix); rp.build(); finder = rp; } @Test public void testBuild() throws IOException { makeSmall(); } @Test public void testQueryCoverage() throws IOException { makeSmall(); int hits = 0; int iters = 10; for (int i = 0; i < iters; i++) { System.out.println("doing " + i); float[] v = TestUtils.randomVector(20); Neighborhood estimated = finder.query(v, 10, 1000, null); Neighborhood actual = actualNeighbors(v, matrix, 10); hits += overlap(estimated, actual); } assertEquals(iters * 10, hits); } @Test public void testQuery() throws IOException { makeBig(); int hits = 0; long elapsedTree = 0; long elapsedBruteForce = 0; int iters = 100; for (int i = 0; i < iters; i++) { float[] v = TestUtils.randomVector(100); long t1 = System.currentTimeMillis(); Neighborhood estimated = finder.query(v, 10, 100, null); long t2 = System.currentTimeMillis(); Neighborhood actual = actualNeighbors(v, matrix, 10); long t3 = System.currentTimeMillis(); elapsedBruteForce += (t3 - t2); elapsedTree += (t2 - t1); hits += overlap(estimated, actual); } System.out.println("Mean overlap is " + (1.0 * hits / iters)); System.out.println("Mean bruteforce millis is " + (1.0 * elapsedBruteForce / iters)); System.out.println("Mean optimized millis is " + (1.0 * elapsedTree / iters)); } private int overlap(Neighborhood n1, Neighborhood n2) { TIntSet ids = new TIntHashSet(); for (int i = 0; i < n1.size(); i++) { ids.add(n1.getId(i)); } int result = 0; for (int i = 0; i < n2.size(); i++) { if (ids.contains(n2.getId(i))) result++; } return result; } private static Neighborhood actualNeighbors(float [] v, DenseMatrix matrix, int n) throws IOException { BruteForceKNNFinder f = new BruteForceKNNFinder(matrix); return f.query(v, n, n, null); } }