/* * File: LinearRegression.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright September 5, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * */ package gov.sandia.cognition.learning.algorithm.regression; import gov.sandia.cognition.annotation.CodeReview; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner; import gov.sandia.cognition.learning.data.DatasetUtil; import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair; import gov.sandia.cognition.learning.data.WeightedInputOutputPair; import gov.sandia.cognition.learning.function.scalar.LinearDiscriminant; import gov.sandia.cognition.learning.function.scalar.VectorFunctionLinearDiscriminant; import gov.sandia.cognition.learning.function.vector.ScalarBasisSet; import gov.sandia.cognition.util.AbstractCloneableSerializable; import gov.sandia.cognition.util.ObjectUtil; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; /** * Computes the least-squares regression for a LinearCombinationFunction * given a dataset. A LinearCombinationFunction is a weighted linear * combination of (potentially) nonlinear basis functions. This looks like * y(x) = a0*f0(x) + a1*f1(x) + ... + an*fn(x) and so forth. * The internal class LinearRegression.Statistic returns the goodness-of-fit * statistics for a set of target-estimate pairs, include a p-value for the * null hypothesis significance. * * @param <InputType> Input class for the basis functions, for example, Double, * Vector, String. * @author Kevin R. Dixon * @since 2.0 */ @CodeReview( reviewer="Kevin R. Dixon", date="2008-09-02", changesNeeded=false, comments={ "Made minor changes to javadoc", "Looks fine." } ) @PublicationReference( author="Wikipedia", title="Linear regression", type=PublicationType.WebPage, year=2008, url="http://en.wikipedia.org/wiki/Linear_regression" ) public class LinearBasisRegression<InputType> extends AbstractCloneableSerializable implements SupervisedBatchLearner<InputType, Double, VectorFunctionLinearDiscriminant<InputType>> { /** * Tolerance for the pseudo inverse in the learn method, {@value}. */ public static final double DEFAULT_PSEUDO_INVERSE_TOLERANCE = 1e-10; /** * Function that maps the InputType to a Vector */ private Evaluator<? super InputType, Vector> inputToVectorMap; /** * Flag to use a pseudoinverse. True to use the expensive, but more * accurate, pseudoinverse routine. False uses a very fast, but * numerically less stable LU solver. Default value is "true". */ private boolean usePseudoInverse; /** * Creates a new instance of LinearRegression * @param basisFunctions * Basis functions to create the ScalarBasisSet from */ public LinearBasisRegression( Collection<? extends Evaluator<? super InputType, Double>> basisFunctions ) { this( new ScalarBasisSet<InputType>( basisFunctions ) ); } /** * Creates a new instance of LinearRegression * @param inputToVectorMap * Function that maps the InputType to a Vector */ public LinearBasisRegression( ScalarBasisSet<InputType> inputToVectorMap ) { this( (Evaluator<? super InputType, Vector>) inputToVectorMap ); } /** * Creates a new instance of LinearRegression * @param inputToVectorMap * Function that maps the InputType to a Vector */ public LinearBasisRegression( Evaluator<? super InputType, Vector> inputToVectorMap ) { this.setInputToVectorMap( inputToVectorMap ); this.setUsePseudoInverse( true ); } @Override public LinearBasisRegression<InputType> clone() { @SuppressWarnings("unchecked") LinearBasisRegression<InputType> clone = (LinearBasisRegression<InputType>) super.clone(); clone.setInputToVectorMap( ObjectUtil.cloneSmart( this.getInputToVectorMap() ) ); return clone; } /** * Computes the linear regression for the given Collection of * InputOutputPairs. The inputs of the pairs is the independent variable, * and the pair output is the dependent variable (variable to predict). * The pairs can have an associated weight to bias the regression equation. * @param data * Collection of InputOutputPairs for the variables. Can be * WeightedInputOutputPairs. * @return * LinearCombinationFunction that minimizes the RMS error of the outputs. */ @Override public VectorFunctionLinearDiscriminant<InputType> learn( Collection<? extends InputOutputPair<? extends InputType, Double>> data ) { // Create the vector-based dataset first ArrayList<WeightedInputOutputPair<Vector,Double>> vectorData = new ArrayList<WeightedInputOutputPair<Vector, Double>>( data.size() ); for (InputOutputPair<? extends InputType, Double> pair : data) { double weight = DatasetUtil.getWeight(pair); Vector xrow = this.inputToVectorMap.evaluate( pair.getInput() ); Double output = pair.getOutput(); vectorData.add( DefaultWeightedInputOutputPair.create( xrow, output, weight ) ); } LinearRegression linear = new LinearRegression(); linear.setUsePseudoInverse(this.getUsePseudoInverse()); LinearDiscriminant weights = linear.learn(vectorData); return new VectorFunctionLinearDiscriminant<InputType>( this.inputToVectorMap, weights ); } /** * Getter for inputToVectorMap * @return * Function that maps the InputType to a Vector */ public Evaluator<? super InputType, Vector> getInputToVectorMap() { return this.inputToVectorMap; } /** * Setter for inputToVectorMap * @param inputToVectorMap * Function that maps the InputType to a Vector */ public void setInputToVectorMap( Evaluator<? super InputType, Vector> inputToVectorMap ) { this.inputToVectorMap = inputToVectorMap; } /** * Getter for usePseudoInverse * @return * Flag to use a pseudoinverse. True to use the expensive, but more * accurate, pseudoinverse routine. False uses a very fast, but * numerically less stable LU solver. Default value is "true". */ public boolean getUsePseudoInverse() { return this.usePseudoInverse; } /** * Setter for usePseudoInverse * @param usePseudoInverse * Flag to use a pseudoinverse. True to use the expensive, but more * accurate, pseudoinverse routine. False uses a very fast, but * numerically less stable LU solver. Default value is "true". */ public void setUsePseudoInverse( boolean usePseudoInverse ) { this.usePseudoInverse = usePseudoInverse; } }