/******************************************************************************* * Copyright (c) 2010 Haifeng Li * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package smile.neighbor; import java.util.List; import smile.math.distance.Distance; import smile.sort.HeapSelect; /** * Brute force linear nearest neighbor search. This simplest solution computes * the distance from the query point to every other point in the database, * keeping track of the "best so far". There are no search data structures to * maintain, so linear search has no space complexity beyond the storage of * the database. Although it is very simple, naive search outperforms space * partitioning approaches (e.g. K-D trees) on higher dimensional spaces. * <p> * By default, the query object (reference equality) is excluded from the neighborhood. * You may change this behavior with <code>setIdenticalExcluded</code>. Note that * you may observe weird behavior with String objects. JVM will pool the string literal * objects. So the below variables * <code> * String a = "ABC"; * String b = "ABC"; * String c = "AB" + "C"; * </code> * are actually equal in reference test <code>a == b == c</code>. With toy data that you * type explicitly in the code, this will cause problems. Fortunately, the data would be * read from secondary storage in production. * </p> * * @param <T> the type of data objects. * * @author Haifeng Li */ public class LinearSearch<T> implements NearestNeighborSearch<T,T>, KNNSearch<T,T>, RNNSearch<T,T> { /** * The dataset of search space. */ private T[] data; /** * The distance function used to determine nearest neighbors. */ private Distance<T> distance; /** * Whether to exclude query object self from the neighborhood. */ private boolean identicalExcluded = true; /** * Constructor. By default, query object self will be excluded from search. */ public LinearSearch(T[] dataset, Distance<T> distance) { this.data = dataset; this.distance = distance; } @Override public String toString() { return String.format("Linear Search (%s)", distance); } /** * Set if exclude query object self from the neighborhood. */ public LinearSearch<T> setIdenticalExcluded(boolean excluded) { identicalExcluded = excluded; return this; } /** * Get whether if query object self be excluded from the neighborhood. */ public boolean isIdenticalExcluded() { return identicalExcluded; } @Override public Neighbor<T,T> nearest(T q) { T neighbor = null; int index = -1; double dist = Double.MAX_VALUE; for (int i = 0; i < data.length; i++) { if (q == data[i] && identicalExcluded) { continue; } double d = distance.d(q, data[i]); if (d < dist) { neighbor = data[i]; index = i; dist = d; } } return new SimpleNeighbor<>(neighbor, index, dist); } @Override public Neighbor<T,T>[] knn(T q, int k) { if (k <= 0) { throw new IllegalArgumentException("Invalid k: " + k); } if (k > data.length) { throw new IllegalArgumentException("Neighbor array length is larger than the dataset size"); } SimpleNeighbor<T> neighbor = new SimpleNeighbor<>(null, 0, Double.MAX_VALUE); @SuppressWarnings("unchecked") SimpleNeighbor<T>[] neighbors = (SimpleNeighbor<T>[]) java.lang.reflect.Array.newInstance(neighbor.getClass(), k); HeapSelect<Neighbor<T,T>> heap = new HeapSelect<>(neighbors); for (int i = 0; i < k; i++) { heap.add(neighbor); neighbor = new SimpleNeighbor<>(null, 0, Double.MAX_VALUE); } for (int i = 0; i < data.length; i++) { if (q == data[i] && identicalExcluded) { continue; } double dist = distance.d(q, data[i]); Neighbor<T,T> datum = heap.peek(); if (dist < datum.distance) { datum.distance = dist; datum.index = i; datum.key = data[i]; datum.value = data[i]; heap.heapify(); } } heap.sort(); return neighbors; } @Override public void range(T q, double radius, List<Neighbor<T,T>> neighbors) { if (radius <= 0.0) { throw new IllegalArgumentException("Invalid radius: " + radius); } for (int i = 0; i < data.length; i++) { if (q == data[i] && identicalExcluded) { continue; } double d = distance.d(q, data[i]); if (d <= radius) { neighbors.add(new SimpleNeighbor<>(data[i], i, d)); } } } }