/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.math.neighborhood; import java.util.Arrays; import java.util.List; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.common.distance.EuclideanDistanceMeasure; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.MatrixSlice; import org.apache.mahout.math.Vector; import org.apache.mahout.math.jet.math.Constants; import org.apache.mahout.math.random.MultiNormal; import org.apache.mahout.math.random.WeightedThing; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThanOrEqualTo; @RunWith(Parameterized.class) public class SearchSanityTest extends MahoutTestCase { private static final int NUM_DATA_POINTS = 1 << 13; private static final int NUM_DIMENSIONS = 20; private static final int NUM_PROJECTIONS = 3; private static final int SEARCH_SIZE = 30; private UpdatableSearcher searcher; private Matrix dataPoints; protected static Matrix multiNormalRandomData(int numDataPoints, int numDimensions) { Matrix data = new DenseMatrix(numDataPoints, numDimensions); MultiNormal gen = new MultiNormal(20); for (MatrixSlice slice : data) { slice.vector().assign(gen.sample()); } return data; } @Parameterized.Parameters public static List<Object[]> generateData() { RandomUtils.useTestSeed(); Matrix dataPoints = multiNormalRandomData(NUM_DATA_POINTS, NUM_DIMENSIONS); return Arrays.asList(new Object[][]{ {new ProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), dataPoints}, {new FastProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), dataPoints}, {new LocalitySensitiveHashSearch(new EuclideanDistanceMeasure(), SEARCH_SIZE), dataPoints}, }); } public SearchSanityTest(UpdatableSearcher searcher, Matrix dataPoints) { this.searcher = searcher; this.dataPoints = dataPoints; } @Test public void testExactMatch() { searcher.clear(); Iterable<MatrixSlice> data = dataPoints; final Iterable<MatrixSlice> batch1 = Iterables.limit(data, 300); List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(batch1, 100)); // adding the data in multiple batches triggers special code in some searchers searcher.addAllMatrixSlices(batch1); assertEquals(300, searcher.size()); Vector q = Iterables.get(data, 0).vector(); List<WeightedThing<Vector>> r = searcher.search(q, 2); assertEquals(0, r.get(0).getValue().minus(q).norm(1), 1.0e-8); final Iterable<MatrixSlice> batch2 = Iterables.limit(Iterables.skip(data, 300), 10); searcher.addAllMatrixSlices(batch2); assertEquals(310, searcher.size()); q = Iterables.get(data, 302).vector(); r = searcher.search(q, 2); assertEquals(0, r.get(0).getValue().minus(q).norm(1), 1.0e-8); searcher.addAllMatrixSlices(Iterables.skip(data, 310)); assertEquals(dataPoints.numRows(), searcher.size()); for (MatrixSlice query : queries) { r = searcher.search(query.vector(), 2); assertEquals("Distance has to be about zero", 0, r.get(0).getWeight(), 1.0e-6); assertEquals("Answer must be substantially the same as query", 0, r.get(0).getValue().minus(query.vector()).norm(1), 1.0e-8); assertTrue("Wrong answer must have non-zero distance", r.get(1).getWeight() > r.get(0).getWeight()); } } @Test public void testNearMatch() { searcher.clear(); List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(dataPoints, 100)); searcher.addAllMatrixSlicesAsWeightedVectors(dataPoints); MultiNormal noise = new MultiNormal(0.01, new DenseVector(20)); for (MatrixSlice slice : queries) { Vector query = slice.vector(); final Vector epsilon = noise.sample(); List<WeightedThing<Vector>> r = searcher.search(query, 2); query = query.plus(epsilon); assertEquals("Distance has to be small", epsilon.norm(2), r.get(0).getWeight(), 1.0e-1); assertEquals("Answer must be substantially the same as query", epsilon.norm(2), r.get(0).getValue().minus(query).norm(2), 1.0e-1); assertTrue("Wrong answer must be further away", r.get(1).getWeight() > r.get(0).getWeight()); } } @Test public void testOrdering() { searcher.clear(); Matrix queries = new DenseMatrix(100, 20); MultiNormal gen = new MultiNormal(20); for (int i = 0; i < 100; i++) { queries.viewRow(i).assign(gen.sample()); } searcher.addAllMatrixSlices(dataPoints); for (MatrixSlice query : queries) { List<WeightedThing<Vector>> r = searcher.search(query.vector(), 200); double x = 0; for (WeightedThing<Vector> thing : r) { assertTrue("Scores must be monotonic increasing", thing.getWeight() >= x); x = thing.getWeight(); } } } @Test public void testRemoval() { searcher.clear(); searcher.addAllMatrixSlices(dataPoints); //noinspection ConstantConditions if (searcher instanceof UpdatableSearcher) { List<Vector> x = Lists.newArrayList(Iterables.limit(searcher, 2)); int size0 = searcher.size(); List<WeightedThing<Vector>> r0 = searcher.search(x.get(0), 2); searcher.remove(x.get(0), 1.0e-7); assertEquals(size0 - 1, searcher.size()); List<WeightedThing<Vector>> r = searcher.search(x.get(0), 1); assertTrue("Vector should be gone", r.get(0).getWeight() > 0); assertEquals("Previous second neighbor should be first", 0, r.get(0).getValue().minus(r0.get(1).getValue()).norm (1), 1.0e-8); searcher.remove(x.get(1), 1.0e-7); assertEquals(size0 - 2, searcher.size()); r = searcher.search(x.get(1), 1); assertTrue("Vector should be gone", r.get(0).getWeight() > 0); // Vectors don't show up in iterator. for (Vector v : searcher) { assertTrue(x.get(0).minus(v).norm(1) > 1.0e-6); assertTrue(x.get(1).minus(v).norm(1) > 1.0e-6); } } else { try { List<Vector> x = Lists.newArrayList(Iterables.limit(searcher, 2)); searcher.remove(x.get(0), 1.0e-7); fail("Shouldn't be able to delete from " + searcher.getClass().getName()); } catch (UnsupportedOperationException e) { // good enough that UOE is thrown } } } @Test public void testSearchFirst() { searcher.clear(); searcher.addAll(dataPoints); for (Vector datapoint : dataPoints) { WeightedThing<Vector> first = searcher.searchFirst(datapoint, false); WeightedThing<Vector> second = searcher.searchFirst(datapoint, true); List<WeightedThing<Vector>> firstTwo = searcher.search(datapoint, 2); assertEquals("First isn't self", 0, first.getWeight(), 0); assertEquals("First isn't self", datapoint, first.getValue()); assertEquals("First doesn't match", first, firstTwo.get(0)); assertEquals("Second doesn't match", second, firstTwo.get(1)); } } @Test public void testSearchLimiting() { searcher.clear(); searcher.addAll(dataPoints); for (Vector datapoint : dataPoints) { List<WeightedThing<Vector>> firstTwo = searcher.search(datapoint, 2); assertThat("Search limit isn't respected", firstTwo.size(), is(lessThanOrEqualTo(2))); } } @Test public void testRemove() { searcher.clear(); for (int i = 0; i < dataPoints.rowSize(); ++i) { Vector datapoint = dataPoints.viewRow(i); searcher.add(datapoint); // As long as points are not searched for right after being added, in FastProjectionSearch, points are not // merged with the main list right away, so if a search for a point occurs before it's merged the pendingAdditions // list also needs to be looked at. // This used to not be the case for searchFirst(), thereby causing removal failures. if (i % 2 == 0) { assertTrue("Failed to find self [search]", searcher.search(datapoint, 1).get(0).getWeight() < Constants.EPSILON); assertTrue("Failed to find self [searchFirst]", searcher.searchFirst(datapoint, false).getWeight() < Constants.EPSILON); assertTrue("Failed to remove self", searcher.remove(datapoint, Constants.EPSILON)); } } } }