/*
* File: KNearestNeighborKDTree.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Aug 4, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.nearest;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.geometry.KDTree;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.Summarizer;
import java.util.ArrayList;
import java.util.Collection;
/**
* A KDTree-based implementation of the k-nearest neighbor algorithm. This
* algorithm has an O(n log(n)) construction time and an O(log(n)) evaluation time.
* @param <InputType> Type of Vectorizable data upon which we determine
* similarity.
* @param <OutputType> Output of the evaluator, like Matrix, Double, String
* @see gov.sandia.cognition.math.geometry.KDTree
* @author Kevin R. Dixon
* @since 3.0
*/
@PublicationReference(
author="Wikipedia",
title="k-nearest neighbor algorithm",
type=PublicationType.WebPage,
year=2008,
url="http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm"
)
public class KNearestNeighborKDTree<InputType extends Vectorizable,OutputType>
extends AbstractKNearestNeighbor<InputType,OutputType>
implements Evaluator<InputType, OutputType>
{
/**
 * KDTree that holds the data to search for neighbors.
 */
private KDTree<InputType,OutputType,InputOutputPair<? extends InputType,OutputType>> data;
/**
 * Creates a new instance of KNearestNeighborKDTree with the default k and
 * no data, distance function, or averager set.
 */
public KNearestNeighborKDTree()
{
this( DEFAULT_K, null, null, null );
}
/**
 * Creates a new instance of KNearestNeighborKDTree
 * @param k
 * Number of neighbors to consider, must be greater than zero
 * @param data
 * KDTree that holds the data to search for neighbors.
 * @param distanceFunction
 * Distance metric that determines how "far" two objects are apart,
 * where lower values indicate two objects are more similar
 * @param averager
 * Creates a single output value from the outputs of the k-nearest
 * neighbors.
 */
public KNearestNeighborKDTree(
int k,
KDTree<InputType,OutputType,InputOutputPair<? extends InputType,OutputType>> data,
Metric<? super InputType> distanceFunction,
Summarizer<? super OutputType,? extends OutputType> averager )
{
super( k, distanceFunction, averager );
this.setData(data);
}
/**
 * Clones this object; the underlying KDTree is cloned as well.
 * @return
 * Deep copy of this KNearestNeighborKDTree.
 */
@Override
public KNearestNeighborKDTree<InputType,OutputType> clone()
{
KNearestNeighborKDTree<InputType,OutputType> clone =
(KNearestNeighborKDTree<InputType,OutputType>) super.clone();
clone.setData( ObjectUtil.cloneSafe( this.getData() ) );
return clone;
}
/**
 * Getter for distanceFunction
 * @return
 * Distance metric that determines how "far" two objects are apart,
 * where lower values indicate two objects are more similar.
 */
@SuppressWarnings("unchecked")
@Override
public Metric<? super InputType> getDivergenceFunction()
{
return (Metric<? super InputType>) super.getDivergenceFunction();
}
/**
 * Sets the divergence function, which must be a Metric so it can drive the
 * KDTree nearest-neighbor search.
 * @param divergenceFunction
 * DivergenceFunction to use. It is cast to Metric at runtime; a
 * ClassCastException results if it does not implement Metric.
 */
@Override
@SuppressWarnings("unchecked")
public void setDivergenceFunction(
DivergenceFunction<? super InputType, ? super InputType> divergenceFunction)
{
this.setDivergenceFunction( (Metric<? super InputType>) divergenceFunction );
}
/**
 * Sets the Metric to use.
 * @param divergenceFunction
 * Metric that determines closeness.
 */
public void setDivergenceFunction(
Metric<? super InputType> divergenceFunction)
{
super.setDivergenceFunction(divergenceFunction);
}
/**
 * Getter for data
 * @return
 * KDTree that holds the data to search for neighbors.
 */
public KDTree<InputType, OutputType,InputOutputPair<? extends InputType,OutputType>> getData()
{
return this.data;
}
/**
 * Setter for data
 * @param data
 * KDTree that holds the data to search for neighbors.
 */
public void setData(
KDTree<InputType, OutputType,InputOutputPair<? extends InputType,OutputType>> data)
{
this.data = data;
}
/**
 * Finds the k-nearest neighbors of the given key in the KDTree and
 * collects their output values.
 * @param key
 * Input around which to search for neighbors.
 * @return
 * Outputs of the (up to) k nearest neighbors to the key.
 */
@Override
protected Collection<OutputType> computeNeighborhood(
InputType key)
{
Collection<InputOutputPair<? extends InputType,OutputType>> neighbors =
this.getData().findNearest(key, this.getK(), this.getDivergenceFunction());
// Copy out just the output values; the caller (the averager) does not
// need the input half of each pair.
ArrayList<OutputType> outputs =
new ArrayList<OutputType>( neighbors.size() );
for( Pair<? extends InputType,OutputType> neighbor : neighbors )
{
outputs.add( neighbor.getSecond() );
}
return outputs;
}
/**
 * Rebalances the internal KDTree to make the search more efficient. This
 * is an O(n log(n)) operation with n samples.
 */
public void rebalance()
{
// NOTE(review): "reblanace" looks misspelled, but it presumably matches
// a (misspelled) method name on the KDTree API -- confirm against the
// KDTree class before renaming either side.
this.setData( this.getData().reblanace() );
}
/**
 * This is a BatchLearner interface for creating a new KNearestNeighbor
 * from a given dataset, simply a pass-through to the constructor of
 * KNearestNeighbor
 * @param <InputType> Type of data upon which the KNearestNeighbor operates,
 * something like Vector, Double, or String
 * @param <OutputType> Output of the evaluator, like Matrix, Double, String
 */
public static class Learner<InputType extends Vectorizable, OutputType>
extends KNearestNeighborKDTree<InputType,OutputType>
implements SupervisedBatchLearner<InputType,OutputType,KNearestNeighborKDTree<InputType, OutputType>>
{
/**
 * Default constructor. Uses the default k, a Euclidean distance metric,
 * and no averager.
 */
public Learner()
{
this( null );
}
/**
 * Creates a new instance of Learner.
 * @param averager
 * Creates a single object from a collection of data.
 */
public Learner(
Summarizer<? super OutputType,? extends OutputType> averager )
{
this( DEFAULT_K, EuclideanDistanceMetric.INSTANCE, averager );
}
/**
 * Creates a new instance of Learner
 * @param k
 * Number of neighbors to consider, must be greater than zero
 * @param divergenceFunction
 * Divergence function that determines how "far" two objects are apart,
 * where lower values indicate two objects are more similar
 * @param averager
 * Creates a single object from a collection of data
 */
public Learner(
int k,
Metric<? super Vectorizable> divergenceFunction,
Summarizer<? super OutputType,? extends OutputType> averager )
{
super( k, null, divergenceFunction, averager );
}
/**
 * Creates a new KNearestNeighbor from a Collection of InputType.
 * We build a balanced KDTree with the data, which is an O(n log(n))
 * operation for n data points.
 * @param data Dataset from which to create a new KNearestNeighbor
 * @return
 * KNearestNeighbor based on the given dataset with a balanced
 * KDTree.
 */
public KNearestNeighborKDTree<InputType, OutputType> learn(
Collection<? extends InputOutputPair<? extends InputType,OutputType>> data )
{
@SuppressWarnings("unchecked")
KNearestNeighborKDTree<InputType, OutputType> clone = this.clone();
KDTree<InputType,OutputType,InputOutputPair<? extends InputType,OutputType>> tree =
KDTree.createBalanced(data);
clone.setData( tree );
return clone;
}
}
}