/* * File: MeanSquaredError.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright September 26, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * * */ package gov.sandia.cognition.learning.performance; import gov.sandia.cognition.learning.data.TargetEstimatePair; import java.util.Collection; /** * The {@code MeanSquaredError} class implements the method for computing the * performance of a supervised learner for a scalar function by the mean squared * between the target and estimated outputs. * * @param <InputType> The type of the input to the evaluator to compute the * performance of. * @author Justin Basilico * @since 2.0 */ public class MeanSquaredErrorEvaluator<InputType> extends AbstractSupervisedPerformanceEvaluator<InputType, Double, Double, Double> { /** * Creates a new instance of MeanSquaredError. */ public MeanSquaredErrorEvaluator() { super(); } /** * {@inheritDoc} * * @param data {@inheritDoc} * @return {@inheritDoc} */ public Double evaluatePerformance( final Collection<? extends TargetEstimatePair<? extends Double,? extends Double>> data ) { return MeanSquaredErrorEvaluator.compute( data ); } /** * Computes the mean squared error for the given pairs of values. The * squared difference between the two values in each pair is computed and * then the mean over all the values is returned. * * @param data The data to compute the mean squared error over. * @return The mean squared error. */ public static double compute( final Collection<? extends TargetEstimatePair<? extends Double, ? extends Double>> data ) { // Since we compute the mean we need to know how many items there are. final int count = data.size(); if (count <= 0) { // There must be at least one item to compute a mean. return 0.0; } // Compute the error for each pair and add it to the sum. double errorSum = 0.0; for (TargetEstimatePair<? extends Double, ? extends Double> pair : data) { final double target = pair.getTarget(); final double estimate = pair.getEstimate(); final double difference = target - estimate; // The error is the squared difference. final double error = difference * difference; errorSum += error; } // Compute the mean of the error sum. return errorSum / (double) count; } }