/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.neighborhood;
import java.util.Arrays;
import java.util.List;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@RunWith(Parameterized.class)
public class SearchQualityTest {
private static final int NUM_DATA_POINTS = 1 << 14;
private static final int NUM_QUERIES = 1 << 10;
private static final int NUM_DIMENSIONS = 40;
private static final int NUM_RESULTS = 2;
private final Searcher searcher;
private final Matrix dataPoints;
private final Matrix queries;
private Pair<List<List<WeightedThing<Vector>>>, Long> reference;
private Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst;
@Parameterized.Parameters
public static List<Object[]> generateData() {
RandomUtils.useTestSeed();
Matrix dataPoints = LumpyData.lumpyRandomData(NUM_DATA_POINTS, NUM_DIMENSIONS);
Matrix queries = LumpyData.lumpyRandomData(NUM_QUERIES, NUM_DIMENSIONS);
DistanceMeasure distanceMeasure = new CosineDistanceMeasure();
Searcher bruteSearcher = new BruteSearch(distanceMeasure);
bruteSearcher.addAll(dataPoints);
Pair<List<List<WeightedThing<Vector>>>, Long> reference = getResultsAndRuntime(bruteSearcher, queries);
Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst =
getResultsAndRuntimeSearchFirst(bruteSearcher, queries);
double bruteSearchAvgTime = reference.getSecond() / (queries.numRows() * 1.0);
System.out.printf("BruteSearch: avg_time(1 query) %f[s]\n", bruteSearchAvgTime);
return Arrays.asList(new Object[][]{
// NUM_PROJECTIONS = 3
// SEARCH_SIZE = 10
{new ProjectionSearch(distanceMeasure, 3, 10), dataPoints, queries, reference, referenceSearchFirst},
{new FastProjectionSearch(distanceMeasure, 3, 10), dataPoints, queries, reference, referenceSearchFirst},
// NUM_PROJECTIONS = 5
// SEARCH_SIZE = 5
{new ProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries, reference, referenceSearchFirst},
{new FastProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries, reference, referenceSearchFirst},
}
);
}
public SearchQualityTest(Searcher searcher, Matrix dataPoints, Matrix queries,
Pair<List<List<WeightedThing<Vector>>>, Long> reference,
Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst) {
this.searcher = searcher;
this.dataPoints = dataPoints;
this.queries = queries;
this.reference = reference;
this.referenceSearchFirst = referenceSearchFirst;
}
@Test
public void testOverlapAndRuntimeSearchFirst() {
searcher.clear();
searcher.addAll(dataPoints);
Pair<List<WeightedThing<Vector>>, Long> results = getResultsAndRuntimeSearchFirst(searcher, queries);
int numFirstMatches = 0;
for (int i = 0; i < queries.numRows(); ++i) {
WeightedThing<Vector> referenceVector = referenceSearchFirst.getFirst().get(i);
WeightedThing<Vector> resultVector = results.getFirst().get(i);
if (referenceVector.getValue().equals(resultVector.getValue())) {
++numFirstMatches;
}
}
double bruteSearchAvgTime = reference.getSecond() / (queries.numRows() * 1.0);
double searcherAvgTime = results.getSecond() / (queries.numRows() * 1.0);
System.out.printf("%s: first matches %d [%d]; avg_time(1 query) %f(s) [%f]\n",
searcher.getClass().getName(), numFirstMatches, queries.numRows(),
searcherAvgTime, bruteSearchAvgTime);
assertEquals("Closest vector returned doesn't match", queries.numRows(), numFirstMatches);
assertTrue("Searcher " + searcher.getClass().getName() + " slower than brute",
bruteSearchAvgTime > searcherAvgTime);
}
@Test
public void testOverlapAndRuntime() {
searcher.clear();
searcher.addAll(dataPoints);
Pair<List<List<WeightedThing<Vector>>>, Long> results = getResultsAndRuntime(searcher, queries);
int numFirstMatches = 0;
int numMatches = 0;
StripWeight stripWeight = new StripWeight();
for (int i = 0; i < queries.numRows(); ++i) {
List<WeightedThing<Vector>> referenceVectors = reference.getFirst().get(i);
List<WeightedThing<Vector>> resultVectors = results.getFirst().get(i);
if (referenceVectors.get(0).getValue().equals(resultVectors.get(0).getValue())) {
++numFirstMatches;
}
for (Vector v : Iterables.transform(referenceVectors, stripWeight)) {
for (Vector w : Iterables.transform(resultVectors, stripWeight)) {
if (v.equals(w)) {
++numMatches;
}
}
}
}
double bruteSearchAvgTime = reference.getSecond() / (queries.numRows() * 1.0);
double searcherAvgTime = results.getSecond() / (queries.numRows() * 1.0);
System.out.printf("%s: first matches %d [%d]; total matches %d [%d]; avg_time(1 query) %f(s) [%f]\n",
searcher.getClass().getName(), numFirstMatches, queries.numRows(),
numMatches, queries.numRows() * NUM_RESULTS, searcherAvgTime, bruteSearchAvgTime);
assertEquals("Closest vector returned doesn't match", queries.numRows(), numFirstMatches);
assertTrue("Searcher " + searcher.getClass().getName() + " slower than brute",
bruteSearchAvgTime > searcherAvgTime);
}
public static Pair<List<List<WeightedThing<Vector>>>, Long> getResultsAndRuntime(Searcher searcher,
Iterable<? extends Vector> queries) {
long start = System.currentTimeMillis();
List<List<WeightedThing<Vector>>> results = searcher.search(queries, NUM_RESULTS);
long end = System.currentTimeMillis();
return new Pair<List<List<WeightedThing<Vector>>>, Long>(results, end - start);
}
public static Pair<List<WeightedThing<Vector>>, Long> getResultsAndRuntimeSearchFirst(
Searcher searcher, Iterable<? extends Vector> queries) {
long start = System.currentTimeMillis();
List<WeightedThing<Vector>> results = searcher.searchFirst(queries, false);
long end = System.currentTimeMillis();
return new Pair<List<WeightedThing<Vector>>, Long>(results, end - start);
}
static class StripWeight implements Function<WeightedThing<Vector>, Vector> {
@Override
public Vector apply(WeightedThing<Vector> input) {
Preconditions.checkArgument(input != null, "input is null");
//noinspection ConstantConditions
return input.getValue();
}
}
}