/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.classification;
import java.io.Serializable;
import java.util.Arrays;
import smile.math.Math;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.neighbor.CoverTree;
import smile.neighbor.KDTree;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;
/**
 * K-nearest neighbor classifier. The k-nearest neighbor algorithm (k-NN)
 * classifies an object by a majority vote of its neighbors, assigning the
 * object to the class most common among its k nearest neighbors (k is a
 * positive integer, typically small). k-NN is a type of instance-based
 * learning, or lazy learning, where the function is only approximated
 * locally and all computation is deferred until classification.
* <p>
* The best choice of k depends upon the data; generally, larger values of
* k reduce the effect of noise on the classification, but make boundaries
* between classes less distinct. A good k can be selected by various
* heuristic techniques, e.g. cross-validation. In binary problems, it is
* helpful to choose k to be an odd number as this avoids tied votes.
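 * <p>
 * As a minimal sketch of such a search, assuming a training set {@code x},
 * {@code y} and a held-out validation set {@code vx}, {@code vy} (these
 * names are illustrative, not part of this class):
 * <pre>{@code
 * int bestK = 1;
 * double bestAccuracy = 0.0;
 * for (int k = 1; k <= 15; k += 2) {  // odd k avoids ties in binary problems
 *     KNN<double[]> model = KNN.learn(x, y, k);
 *     int correct = 0;
 *     for (int i = 0; i < vx.length; i++) {
 *         if (model.predict(vx[i]) == vy[i]) correct++;
 *     }
 *     double accuracy = (double) correct / vx.length;
 *     if (accuracy > bestAccuracy) {
 *         bestAccuracy = accuracy;
 *         bestK = k;
 *     }
 * }
 * }</pre>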
* <p>
 * A drawback of basic majority voting is that classes with more frequent
 * training instances tend to dominate the prediction, as they are likely to
 * appear among the k nearest neighbors simply because of their abundance.
 * One way to overcome this problem is to weight each vote by the distance
 * from the test point to the neighbor.
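 * <p>
 * As an illustrative sketch (this class itself uses unweighted majority
 * voting), an inverse-distance weighting of the votes over the returned
 * neighbors could look like:
 * <pre>{@code
 * double[] votes = new double[c];
 * for (Neighbor<T, T> neighbor : neighbors) {
 *     // Weight each vote by the reciprocal of the neighbor's distance,
 *     // guarding against division by zero on exact matches.
 *     votes[y[neighbor.index]] += 1.0 / (neighbor.distance + 1E-10);
 * }
 * }</pre>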
* <p>
* Often, the classification accuracy of k-NN can be improved significantly
* if the distance metric is learned with specialized algorithms such as
* Large Margin Nearest Neighbor or Neighborhood Components Analysis.
* <p>
* Nearest neighbor rules in effect compute the decision boundary in an
* implicit manner. It is also possible to compute the decision boundary
* itself explicitly, and to do so in an efficient manner so that the
* computational complexity is a function of the boundary complexity.
* <p>
* The nearest neighbor algorithm has some strong consistency results. As
* the amount of data approaches infinity, the algorithm is guaranteed to
* yield an error rate no worse than twice the Bayes error rate (the minimum
* achievable error rate given the distribution of the data). k-NN is
 * guaranteed to approach the Bayes error rate for some value of k (where k
* increases as a function of the number of data points).
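 * <p>
 * Typical usage, assuming a training matrix {@code x}, labels {@code y},
 * and a query vector {@code q}:
 * <pre>{@code
 * KNN<double[]> knn = KNN.learn(x, y, 3);
 * int label = knn.predict(q);
 * }</pre>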
*
* @author Haifeng Li
*/
public class KNN<T> implements SoftClassifier<T>, Serializable {
private static final long serialVersionUID = 1L;
/**
* The data structure for nearest neighbor search.
*/
private KNNSearch<T, T> knn;
/**
* The labels of training samples.
*/
private int[] y;
/**
* The number of neighbors for decision.
*/
private int k;
/**
* The number of classes.
*/
private int c;
/**
* Trainer for KNN classifier.
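     * For example, assuming training samples {@code x} with labels {@code y}:
     * <pre>{@code
     * KNN.Trainer<double[]> trainer = new KNN.Trainer<>(new EuclideanDistance(), 3);
     * KNN<double[]> model = trainer.train(x, y);
     * }</pre>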
*/
public static class Trainer<T> extends ClassifierTrainer<T> {
/**
* The number of neighbors.
*/
private int k;
/**
* The distance functor.
*/
private Distance<T> distance;
/**
* Constructor.
*
* @param distance the distance metric functor.
* @param k the number of neighbors.
*/
public Trainer(Distance<T> distance, int k) {
if (k < 1) {
throw new IllegalArgumentException("Invalid k of k-NN: " + k);
}
this.distance = distance;
this.k = k;
}
@Override
public KNN<T> train(T[] x, int[] y) {
return new KNN<>(x, y, distance, k);
}
}
/**
* Constructor.
* @param knn k-nearest neighbor search data structure of training instances.
* @param y training labels in [0, c), where c is the number of classes.
* @param k the number of neighbors for classification.
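     * <p>
     * For example, with a KD-tree over {@code double[]} samples:
     * <pre>{@code
     * KNN<double[]> knn = new KNN<>(new KDTree<>(x, x), y, 5);
     * }</pre>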
*/
public KNN(KNNSearch<T, T> knn, int[] y, int k) {
this.knn = knn;
this.k = k;
this.y = y;
// class label set.
int[] labels = Math.unique(y);
Arrays.sort(labels);
for (int i = 0; i < labels.length; i++) {
if (labels[i] < 0) {
throw new IllegalArgumentException("Negative class label: " + labels[i]);
}
            if (i > 0 && labels[i] - labels[i-1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i-1] + 1));
            }
}
c = labels.length;
if (c < 2) {
throw new IllegalArgumentException("Only one class.");
}
}
/**
     * Constructor. This creates a 1-NN classifier, i.e. k = 1.
* @param x training samples.
* @param y training labels in [0, c), where c is the number of classes.
* @param distance the distance measure for finding nearest neighbors.
*/
public KNN(T[] x, int[] y, Distance<T> distance) {
this(x, y, distance, 1);
}
/**
     * Constructor. Learn the k-NN classifier from data of any generalized
     * type with a given distance definition.
     * @param x training samples.
     * @param y training labels in [0, c), where c is the number of classes.
     * @param distance the distance measure for finding nearest neighbors.
     * @param k the number of neighbors for classification.
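     * <p>
     * For example, with the Euclidean metric (a cover tree is then used for
     * the neighbor search, since {@code EuclideanDistance} is a {@code Metric}):
     * <pre>{@code
     * KNN<double[]> knn = new KNN<>(x, y, new EuclideanDistance(), 5);
     * }</pre>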
*/
public KNN(T[] x, int[] y, Distance<T> distance, int k) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
if (k < 1) {
throw new IllegalArgumentException("Illegal k = " + k);
}
// class label set.
int[] labels = Math.unique(y);
Arrays.sort(labels);
for (int i = 0; i < labels.length; i++) {
if (labels[i] < 0) {
throw new IllegalArgumentException("Negative class label: " + labels[i]);
}
            if (i > 0 && labels[i] - labels[i-1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i-1] + 1));
            }
}
c = labels.length;
if (c < 2) {
throw new IllegalArgumentException("Only one class.");
}
this.y = y;
this.k = k;
        if (distance instanceof Metric) {
            // A metric satisfies the triangle inequality, which allows the use
            // of a cover tree to accelerate the nearest neighbor search.
            knn = new CoverTree<>(x, (Metric<T>) distance);
        } else {
            // An arbitrary distance function offers no such structure, so fall
            // back to brute-force linear search.
            knn = new LinearSearch<>(x, distance);
        }
}
/**
* Learn the 1-NN classifier from data of type double[].
* @param x the training samples.
     * @param y training labels in [0, c), where c is the number of classes.
     * @return a trained 1-NN classifier.
*/
public static KNN<double[]> learn(double[][] x, int[] y) {
return learn(x, y, 1);
}
/**
     * Learn the k-NN classifier from data of type double[].
     * @param x training samples.
     * @param y training labels in [0, c), where c is the number of classes.
     * @param k the number of neighbors for classification.
     * @return a trained k-NN classifier.
*/
public static KNN<double[]> learn(double[][] x, int[] y, int k) {
if (x.length != y.length) {
throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
}
if (k < 1) {
throw new IllegalArgumentException("Illegal k = " + k);
}
        KNNSearch<double[], double[]> knn;
        if (x[0].length < 10) {
            // KD-trees are efficient in low dimensions but degrade as the
            // dimensionality grows, so use a cover tree beyond that point.
            knn = new KDTree<>(x, x);
        } else {
            knn = new CoverTree<>(x, new EuclideanDistance());
        }
return new KNN<>(knn, y, k);
}
@Override
public int predict(T x) {
return predict(x, null);
}
@Override
public int predict(T x, double[] posteriori) {
if (posteriori != null && posteriori.length != c) {
throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, c));
}
        Neighbor<T,T>[] neighbors = knn.knn(x, k);
        if (k == 1) {
            // With a single neighbor, all probability mass goes to its label.
            int label = y[neighbors[0].index];
            if (posteriori != null) {
                Arrays.fill(posteriori, 0.0);
                posteriori[label] = 1.0;
            }
            return label;
        }

        // Count the votes of the k nearest neighbors.
        int[] count = new int[c];
        for (int i = 0; i < k; i++) {
            count[y[neighbors[i].index]]++;
        }

        if (posteriori != null) {
            // Estimate posterior probabilities as the fraction of votes.
            for (int i = 0; i < c; i++) {
                posteriori[i] = (double) count[i] / k;
            }
        }
        // Return the class with the most votes, breaking ties in favor of
        // the smaller class label.
        int max = 0;
        int idx = 0;
for (int i = 0; i < c; i++) {
if (count[i] > max) {
max = count[i];
idx = i;
}
}
return idx;
}
}