/*
* File: SumSquaredErrorCostFunction.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jul 4, 2008, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.cost;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import java.util.Collection;
/**
 * Sum-squared error cost function for vector-valued functions.
 * <p>
 * Given input-output pairs (x_i, y_i) with per-example weights w_i (read via
 * {@link DatasetUtil#getWeight}), the cost of a function f is
 * <pre>
 *     sum_i( w_i * ||y_i - f(x_i)||^2 ) / ( 2 * sum_i( w_i ) )
 * </pre>
 * which is reported as 0.0 when the total weight is zero.
 * <p>
 * Both the cost and its parameter gradient are computed in two stages:
 * <ul>
 * <li>a "partial" stage that returns an unnormalized numerator together
 * with a denominator;</li>
 * <li>an "amalgamate" stage that sums the numerators and denominators of
 * several partial results before dividing.</li>
 * </ul>
 *
 * @author Kevin R. Dixon
 * @since 2.1
 */
public class SumSquaredErrorCostFunction
    extends AbstractParallelizableCostFunction
{

    /**
     * Creates a new instance of SumSquaredErrorCostFunction without a
     * dataset.
     */
    public SumSquaredErrorCostFunction()
    {
        // Explicitly typed null so the dataset constructor is the one invoked.
        this( (Collection<? extends InputOutputPair<? extends Vector, Vector>>) null );
    }

    /**
     * Creates a new instance of SumSquaredErrorCostFunction.
     *
     * @param dataset
     *     The dataset of examples to use to compute the error. May be null
     *     (the no-argument constructor passes null).
     */
    public SumSquaredErrorCostFunction(
        Collection<? extends InputOutputPair<? extends Vector, Vector>> dataset )
    {
        super( dataset );
    }

    /**
     * Creates a copy of this cost function using the superclass clone.
     *
     * @return
     *     A clone of this cost function.
     */
    @Override
    public SumSquaredErrorCostFunction clone()
    {
        return (SumSquaredErrorCostFunction) super.clone();
    }

    /**
     * Computes the unnormalized, weighted sum-squared error of the given
     * evaluator over this cost function's dataset.
     *
     * @param evaluator
     *     Evaluator whose output for each input is compared to the target.
     * @return
     *     An EvaluatePartialSSE whose first element is
     *     sum( w * ||target - estimate||^2 ) and whose second element is
     *     2 * sum( w ). Pass it to evaluateAmalgamate to get the cost.
     */
    public Object evaluatePartial(
        Evaluator<? super Vector, ? extends Vector> evaluator )
    {
        double sumSquaredError = 0.0;
        double weightSum = 0.0;
        for (InputOutputPair<? extends Vector,Vector> pair : this.getCostParameters() )
        {
            // Squared Euclidean distance between the target and the estimate.
            Vector target = pair.getOutput();
            Vector estimate = evaluator.evaluate( pair.getInput() );
            double errorSquared = target.euclideanDistanceSquared( estimate );
            double weight = DatasetUtil.getWeight(pair);
            weightSum += weight;
            sumSquaredError += weight * errorSquared;
        }

        // The factor of 2 in the denominator implements the 1/2 in the cost.
        weightSum *= 2.0;
        return new EvaluatePartialSSE( sumSquaredError, weightSum );
    }

    /**
     * Combines partial results produced by evaluatePartial into the cost.
     *
     * @param partialResults
     *     Partial results. Each must be an EvaluatePartialSSE; any other
     *     type causes a ClassCastException.
     * @return
     *     The sum of the numerators divided by the sum of the denominators,
     *     or 0.0 if the summed denominator is zero.
     */
    public Double evaluateAmalgamate(
        Collection<Object> partialResults )
    {
        double numerator = 0.0;
        double denominator = 0.0;
        for( Object result : partialResults )
        {
            EvaluatePartialSSE sse = (EvaluatePartialSSE) result;
            numerator += sse.getFirst();
            denominator += sse.getSecond();
        }

        // Avoid 0/0 when there is no data (or all weights are zero).
        if( denominator == 0.0 )
        {
            return 0.0;
        }
        else
        {
            return numerator / denominator;
        }
    }

    /**
     * Computes the unnormalized parameter gradient over this cost function's
     * dataset.
     *
     * @param function
     *     Function whose parameter gradient is computed.
     * @return
     *     A GradientPartialSSE whose first element is
     *     sum( w * (f(x) - y)^T * G(x) ), where G(x) is
     *     function.computeParameterGradient(x), and whose second element is
     *     sum( w ). Unlike evaluatePartial, the denominator is NOT doubled
     *     here; computeParameterGradientAmalgamate applies the factor of 2.
     *     NOTE(review): for an empty dataset the first element is whatever
     *     RingAccumulator.getSum() returns with nothing accumulated
     *     (possibly null) - confirm callers handle this.
     */
    public Object computeParameterGradientPartial(
        GradientDescendable function )
    {
        RingAccumulator<Vector> parameterDelta =
            new RingAccumulator<Vector>();

        double denominator = 0.0;
        for (InputOutputPair<? extends Vector, ? extends Vector> pair : this.getCostParameters())
        {
            Vector input = pair.getInput();
            Vector target = pair.getOutput();

            // Subtract in place on the evaluated output rather than on the
            // target so the dataset itself is never modified.
            Vector negativeError = function.evaluate( input );
            negativeError.minusEquals( target );

            double weight = DatasetUtil.getWeight(pair);
            if (weight != 1.0)
            {
                // Skip the scaling work in the common unit-weight case.
                negativeError.scaleEquals( weight );
            }
            denominator += weight;

            Matrix gradient = function.computeParameterGradient( input );
            Vector parameterUpdate = negativeError.times( gradient );
            parameterDelta.accumulate( parameterUpdate );
        }

        Vector negativeSum = parameterDelta.getSum();
        return new GradientPartialSSE( negativeSum, denominator );
    }

    /**
     * Combines partial results produced by computeParameterGradientPartial
     * into the parameter gradient.
     *
     * @param partialResults
     *     Partial results. Each must be a GradientPartialSSE; any other type
     *     causes a ClassCastException.
     * @return
     *     The summed numerator vectors scaled by 1 / (2 * summed
     *     denominator). If the summed denominator is zero, the summed
     *     numerator is returned unscaled.
     */
    public Vector computeParameterGradientAmalgamate(
        Collection<Object> partialResults )
    {
        RingAccumulator<Vector> numerator = new RingAccumulator<Vector>();
        double denominator = 0.0;
        for( Object result : partialResults )
        {
            GradientPartialSSE sse = (GradientPartialSSE) result;
            numerator.accumulate( sse.getFirst() );
            denominator += sse.getSecond();
        }

        Vector scaleSum = numerator.getSum();
        if( denominator != 0.0 )
        {
            // NOTE(review): the cost is sum(w*||e||^2) / (2*sum(w)). If G is
            // the Jacobian of the outputs w.r.t. the parameters, its
            // analytic gradient is sum(w * e^T * G) / sum(w). Dividing by
            // 2*sum(w) here yields half of that (consistent with Cache.Jte);
            // confirm this scaling is intended.
            scaleSum.scaleEquals( 1.0 / (2.0*denominator) );
        }
        return scaleSum;
    }

    /**
     * Computes the sum-squared error cost directly from target/estimate pairs.
     *
     * @param data
     *     Pairs of targets and estimates; weights are read via
     *     DatasetUtil.getWeight.
     * @return
     *     sum( w * ||target - estimate||^2 ) / ( 2 * sum( w ) ), or 0.0 if
     *     the weights sum to zero (including an empty collection).
     */
    @Override
    public Double evaluatePerformance(
        Collection<? extends TargetEstimatePair<? extends Vector, ? extends Vector>> data )
    {
        double sumSquaredError = 0.0;
        double weightSum = 0.0;
        for (TargetEstimatePair<? extends Vector, ? extends Vector> pair : data)
        {
            // Squared Euclidean distance between the target and the estimate.
            Vector target = pair.getTarget();
            Vector estimate = pair.getEstimate();
            double errorSquared = target.euclideanDistanceSquared( estimate );
            double weight = DatasetUtil.getWeight(pair);
            weightSum += weight;
            sumSquaredError += weight * errorSquared;
        }

        // Same normalization as evaluatePartial + evaluateAmalgamate.
        weightSum *= 2.0;
        if( weightSum == 0.0 )
        {
            return 0.0;
        }
        else
        {
            return sumSquaredError / weightSum;
        }
    }

    /**
     * Caches often-used values for the cost function, as computed by
     * {@link #compute}.
     */
    public static class Cache
        extends AbstractCloneableSerializable
    {

        /**
         * Jacobian. As produced by compute(): the sum of the per-example
         * parameter gradients divided by 2 * sum of the weights (or by 1.0
         * when the weights sum to zero).
         */
        public final Matrix J;

        /**
         * Inner-product of the Jacobian matrix: J.transpose().times( J )
         */
        public final Matrix JtJ;

        /**
         * Jacobian transpose times error. As produced by compute(): the sum
         * of w * (f(x) - y)^T * gradient, using the same scaling as J.
         */
        public final Vector Jte;

        /**
         * Cost-function value of the parameter set. As produced by
         * compute(): sum( w * ||f(x) - y||^2 ) divided by the same factor
         * as J, matching evaluatePerformance for nonzero total weight.
         */
        public final double parameterCost;

        /**
         * Creates a new instance of Cache
         * @param J
         * Jacobian
         * @param JtJ
         * Inner-product of the Jacobian matrix: J.transpose().times( J )
         * @param Jte
         * Jacobian transpose times Error: J.transpose().times( error )
         * @param parameterCost
         * Cost-function value of the parameter set
         */
        protected Cache(
            Matrix J,
            Matrix JtJ,
            Vector Jte,
            double parameterCost )
        {
            this.J = J;
            this.JtJ = JtJ;
            this.Jte = Jte;
            this.parameterCost = parameterCost;
        }

        /**
         * Computes often-used parameters of a sum-squared error term.
         * NOTE(review): for an empty dataset, RingAccumulator.getSum() may
         * return null, which would make J.scaleEquals fail - confirm.
         *
         * @param objectToOptimize
         * GradientDescendable to compute the statistics of
         * @param data
         * Dataset to consider
         * @return
         * Cache containing the cached cost-function parameters
         */
        public static Cache compute(
            GradientDescendable objectToOptimize,
            Collection<? extends InputOutputPair<? extends Vector,Vector>> data )
        {
            // Unweighted sum of the per-example parameter gradients.
            RingAccumulator<Matrix> gradientSum = new RingAccumulator<Matrix>();
            // Sum of w * (f(x) - y)^T * gradient.
            RingAccumulator<Vector> gradientError = new RingAccumulator<Vector>();

            // This is very close to the
            // MeanSquaredErrorCostFunction.computeParameterGradient() method
            double weightSum = 0.0;
            double parameterCost = 0.0;
            for (InputOutputPair<? extends Vector, ? extends Vector> pair : data)
            {
                // Subtract in place on the evaluated output rather than on
                // pair.getOutput(), which would modify the dataset.
                Vector negativeError = objectToOptimize.evaluate( pair.getInput() );
                negativeError.minusEquals( pair.getOutput() );

                // Take the squared norm before weighting the error vector.
                double norm2 = negativeError.norm2Squared();
                double weight = DatasetUtil.getWeight(pair);
                if (weight != 1.0)
                {
                    negativeError.scaleEquals( weight );
                }

                weightSum += weight;
                parameterCost += norm2 * weight;

                Matrix gradient =
                    objectToOptimize.computeParameterGradient( pair.getInput() );
                // NOTE(review): the gradient is summed unweighted but is later
                // divided by the weighted total - confirm for non-unit weights.
                gradientSum.accumulate( gradient );
                gradientError.accumulate( negativeError.times( gradient ) );
            }

            // Same 2 * sum(w) normalization as evaluatePerformance. Fall
            // back to 1.0 to avoid dividing by zero.
            weightSum *= 2.0;
            if( weightSum == 0.0 )
            {
                weightSum = 1.0;
            }

            // This is the Jacobian
            Matrix J = gradientSum.getSum();
            J.scaleEquals( 1.0 / weightSum );
            Matrix JtJ = J.transpose().times( J );

            // The accumulated error is the (weighted) negative error
            // f(x) - y, so no sign flip is applied. Jte gets the same
            // 1 / (2 * sum(w)) scaling as J and the cost.
            Vector Jte = gradientError.getSum();
            Jte.scaleEquals( 1.0 / weightSum );

            // Make sure the cost is normalized by the weights
            parameterCost /= weightSum;

            return new Cache( J, JtJ, Jte, parameterCost );
        }

    }

    /**
     * Partial result from the SSE evaluate computation: (numerator,
     * denominator) as produced by evaluatePartial.
     */
    private static class EvaluatePartialSSE
        extends DefaultPair<Double,Double>
    {

        /**
         * Creates a new instance of EvaluatePartialSSE
         * @param numerator
         * Numerator
         * @param denominator
         * Denominator
         */
        public EvaluatePartialSSE(
            Double numerator,
            Double denominator )
        {
            super( numerator, denominator );
        }

    }

    /**
     * Partial result from the SSE gradient computation: (numerator vector,
     * denominator) as produced by computeParameterGradientPartial.
     */
    public static class GradientPartialSSE
        extends DefaultPair<Vector,Double>
    {

        /**
         * Creates a new instance of GradientPartialSSE
         * @param numerator
         * Numerator
         * @param denominator
         * Denominator
         */
        public GradientPartialSSE(
            Vector numerator,
            Double denominator )
        {
            super( numerator, denominator );
        }

    }

}