/* * File: LocallyWeightedFunction.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Dec 2, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * */ package gov.sandia.cognition.learning.algorithm.regression; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner; import gov.sandia.cognition.learning.data.DatasetUtil; import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.data.WeightedInputOutputPair; import gov.sandia.cognition.learning.function.kernel.Kernel; import gov.sandia.cognition.util.AbstractCloneableSerializable; import java.util.ArrayList; import java.util.Collection; /** * LocallyWeightedFunction is a generalization of the k-nearest neighbor * concept, also known as "Instance-Based Learning", "Memory-Based Learning", * "Nonparametric Regression", "Case-Based Regression", or * "Kernel-Based Regression". This approach essentially has no up-front * learning time, but creates a local function approximation in response to * a evaluate() call. The local function approximation is created by weighting * the original dataset by a value given by a Kernel against each input sample * in the dataset. * <BR><BR> * KernelWeightedRobustRegression is different from LocallyWeightedFunction in * that KWRR creates a global function approximator and holds for all inputs. * Thus, up-front learning time for KWRR is relatively high, but evaluation time * is relatively low. On the other hand, LWL creates a local function * approximator in response to each evaluation, and LWL does not create a global * function approximator. As such, LWL has (almost) no up-front learning time, * but each evaluation requires relatively high computation. The cost of LWL * function evaluation depends strongly on the type of learner given to the * algorithm. If you use fast or closed-form learners, then you may not notice * the evaluation time. But if you use some brain-dead iterative technique, * like Gradient Descent, then use LWL at your own risk. * <BR><BR> * KWRR is more appropriate when you know the general structure of your data, * but it is riddled with outliers. LWL is more appropriate when you don't * know/understand the general trend of your data AND you can afford evaluation * time to be somewhat costly. * * @see KernelWeightedRobustRegression * * @param <InputType> * Input class to map onto the Output class * @param <OutputType> * Output of the Evaluator * @author Kevin R. Dixon */ @PublicationReference( author="Andrew W. Moore", title="Instance-based learning (aka Case-based or Memory-based or non-parametric)", type=PublicationType.WebPage, year=2006, url="http://www.autonlab.org/tutorials/mbl.html" ) public class LocallyWeightedFunction<InputType, OutputType> implements Evaluator<InputType, OutputType> { /** * Kernel that provides the weights between an input and each sample * in the input dataset */ private Kernel<? super InputType> kernel; /** * Learner that takes the Collection of WeightedInputOutputPairs from * the Kernel reweighting and creates a local function approximation at * the given input. I would strongly recommend using fast or closed-form * learners for this. */ private SupervisedBatchLearner<InputType,OutputType,?> learner; /** * Original (weighted) dataset */ private ArrayList<WeightedInputOutputPair<InputType, OutputType>> rawData; /** * Dataset containing the weights in response to an evaluate() call. The * weights in this dataset will be a product of the original dataset * weights times the weights from the Kernel response to the given input * to the evaluate() method call. */ private ArrayList<DefaultWeightedInputOutputPair<InputType, OutputType>> locallyWeightedData; /** * Local function approximator created from the learner and the * locallyWeightedData, may be null if you haven't called evaluate() yet */ private Evaluator<? super InputType, ? extends OutputType> localApproximator; /** * Evaluator that implements the concept of LocallyWeightedLearning. That * is, given an input point, this function re-weights the dataset according * to how "close" the dataset inputs are to the given input. An inner-loop * learner then uses the re-weighted to compute a local function * approximator for this input. The output of this class is the output * of the local function approximator, which is recomputed each time * evaluate() is called. Thus, evaluate() on this method is relatively * expensive (because it calls learn() on the given BatchLearner) * * @param kernel * Kernel that provides the weights between an input and each sample * in the input dataset * @param rawData * Original (weighted) dataset * @param learner * Learner that takes the Collection of WeightedInputOutputPairs from * the Kernel reweighting and creates a local function approximation at * the given input. I would strongly recommend using fast or closed-form * learners for this. */ public LocallyWeightedFunction( Kernel<? super InputType> kernel, Collection<? extends InputOutputPair<? extends InputType, OutputType>> rawData, SupervisedBatchLearner<InputType,OutputType,?> learner ) { this.setKernel( kernel ); ArrayList<WeightedInputOutputPair<InputType,OutputType>> weightedRawData = new ArrayList<WeightedInputOutputPair<InputType,OutputType>>( rawData.size() ); this.locallyWeightedData = new ArrayList<DefaultWeightedInputOutputPair<InputType, OutputType>>( rawData.size() ); for (InputOutputPair<? extends InputType, ? extends OutputType> pair : rawData) { double weight = DatasetUtil.getWeight(pair); // Note that these have to be different instances of // WeightedInputOutputPair because we'll be overwriting the weight // in locallyWeightedData and we don't want to blow away the weight // in the original data weightedRawData.add( new DefaultWeightedInputOutputPair<InputType, OutputType>( pair, weight ) ); this.locallyWeightedData.add( new DefaultWeightedInputOutputPair<InputType, OutputType>( pair, weight ) ); } this.rawData = weightedRawData; this.setLearner( learner ); this.setLocalApproximator( null ); } /** * Getter for kernel * @return * Kernel that provides the weights between an input and each sample * in the input dataset */ public Kernel<? super InputType> getKernel() { return this.kernel; } /** * Setter for kernel * @param kernel * Kernel that provides the weights between an input and each sample * in the input dataset */ public void setKernel( Kernel<? super InputType> kernel ) { this.kernel = kernel; } /** * Getter for learner * @return * Learner that takes the Collection of WeightedInputOutputPairs from * the Kernel reweighting and creates a local function approximation at * the given input. I would strongly recommend using fast or closed-form * learners for this. */ public SupervisedBatchLearner<InputType,OutputType,?> getLearner() { return this.learner; } /** * Setter for learner * @param learner * Learner that takes the Collection of WeightedInputOutputPairs from * the Kernel reweighting and creates a local function approximation at * the given input. I would strongly recommend using fast or closed-form * learners for this. */ public void setLearner( SupervisedBatchLearner<InputType,OutputType,?> learner ) { this.learner = learner; } /** * This function re-weights the dataset according to the Kernel value * between the input and each input in the dataset. This re-weighted * dataset is then given to the learner to create a local approximator * that then evaluate the input to produce the prediction * @param input * Input to create a local approximator for, using the Kernel to weight * the original dataset * @return * Approximation at the given input using the Kernel weights, the original * (weighted) dataset, and the BatchLearner */ @Override public OutputType evaluate( InputType input ) { // Re-weight the samples according to the kernel weight times the // original sample weight, then run the learner on this locally // weighted dataset for (int i = 0; i < this.rawData.size(); i++) { WeightedInputOutputPair<? extends InputType, ? extends OutputType> originalPair = this.rawData.get( i ); DefaultWeightedInputOutputPair<InputType, OutputType> locallyWeightedPair = this.locallyWeightedData.get( i ); double kernelWeight = this.kernel.evaluate( input, originalPair.getInput() ); double originalWeight = originalPair.getWeight(); double localWeight = kernelWeight * originalWeight; locallyWeightedPair.setWeight( localWeight ); } this.localApproximator = this.learner.learn( this.locallyWeightedData ); return this.localApproximator.evaluate( input ); } /** * Getter for localApproximator * @return * Dataset containing the weights in response to an evaluate() call. The * weights in this dataset will be a product of the original dataset * weights times the weights from the Kernel response to the given input * to the evaluate() method call. */ public Evaluator<? super InputType, ? extends OutputType> getLocalApproximator() { return this.localApproximator; } /** * Setter for localApproximator * @param localApproximator * Dataset containing the weights in response to an evaluate() call. The * weights in this dataset will be a product of the original dataset * weights times the weights from the Kernel response to the given input * to the evaluate() method call. */ public void setLocalApproximator( Evaluator<? super InputType, ? extends OutputType> localApproximator ) { this.localApproximator = localApproximator; } /** * Learning algorithm for creating LocallyWeightedFunctions. This is * essentially just a pass through, as no learning takes place, but a * model is fitted to the data about each point on an evaluate() call * @param <InputType> * Input class to map onto the Output class * @param <OutputType> * Output of the Evaluator */ public static class Learner<InputType, OutputType> extends AbstractCloneableSerializable implements SupervisedBatchLearner<InputType,OutputType,LocallyWeightedFunction<? super InputType,OutputType>> { /** * Kernel that provides the weights between an input and each sample * in the input dataset */ private Kernel<? super InputType> kernel; /** * Learner that takes the Collection of WeightedInputOutputPairs from * the Kernel reweighting and creates a local function approximation at * the given input. I would strongly recommend using fast or closed-form * learners for this. */ private SupervisedBatchLearner<InputType,OutputType,?> learner; /** * Creates a new instance of LocallyWeightedFunction * @param kernel * Kernel that provides the weights between an input and each sample * in the input dataset * @param learner * Learner that takes the Collection of WeightedInputOutputPairs from * the Kernel reweighting and creates a local function approximation at * the given input. I would strongly recommend using fast or closed-form * learners for this. */ public Learner( Kernel<? super InputType> kernel, SupervisedBatchLearner<InputType,OutputType,?> learner ) { this.setKernel( kernel ); this.setLearner( learner ); } @Override public LocallyWeightedFunction<InputType, OutputType> learn( Collection<? extends InputOutputPair<? extends InputType, OutputType>> data ) { return new LocallyWeightedFunction<InputType, OutputType>( this.getKernel(), data, this.getLearner() ); } /** * Getter for kernel * @return * Kernel that provides the weights between an input and each sample * in the input dataset */ public Kernel<? super InputType> getKernel() { return this.kernel; } /** * Setter for kernel * @param kernel * Kernel that provides the weights between an input and each sample * in the input dataset */ public void setKernel( Kernel<? super InputType> kernel ) { this.kernel = kernel; } /** * Getter for learner * @return * Learner that takes the Collection of WeightedInputOutputPairs from * the Kernel reweighting and creates a local function approximation at * the given input. I would strongly recommend using fast or closed-form * learners for this. */ public SupervisedBatchLearner<InputType,OutputType, ?> getLearner() { return this.learner; } /** * Setter for learner * @param learner The learner to use */ public void setLearner( SupervisedBatchLearner<InputType,OutputType,?> learner ) { this.learner = learner; } } }