/*
* File: AbstractSupervisedCostFunction.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Dec 20, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.function.cost;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultWeightedTargetEstimatePair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.learning.data.WeightedTargetEstimatePair;
import gov.sandia.cognition.learning.performance.AbstractSupervisedPerformanceEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
/**
 * Partial implementation of {@code SupervisedCostFunction}.
 * <p>
 * The cost of an {@code Evaluator} is computed against a labeled dataset
 * (the "cost parameters"). For each example, the example's output (target)
 * is paired with the evaluator's estimate for the example's input and with
 * the weight reported by {@code DatasetUtil.getWeight}. The resulting pairs
 * are then handed to {@link #evaluatePerformance(Collection)}, which
 * subclasses implement to reduce them to a single cost value.
 *
 * @param <InputType> Input type of the dataset and Evaluator
 * @param <TargetType> Output type (labels) of the dataset and Evaluator
 * @author Kevin R. Dixon
 * @since 2.0
 */
public abstract class AbstractSupervisedCostFunction<InputType, TargetType>
    extends AbstractSupervisedPerformanceEvaluator<InputType, TargetType, TargetType, Double>
    implements SupervisedCostFunction<InputType, TargetType>
{

    /**
     * Labeled dataset to use to evaluate the cost against.
     * <p>
     * May be null; the no-argument constructor leaves it unset. It must be
     * set before {@link #evaluate(Evaluator)} is called.
     */
    private Collection<? extends InputOutputPair<? extends InputType, TargetType>> costParameters;

    /**
     * Creates a new instance of AbstractSupervisedCostFunction with no
     * cost parameters (null). Call
     * {@link #setCostParameters(Collection)} before evaluating.
     */
    public AbstractSupervisedCostFunction()
    {
        this.setCostParameters( null );
    }

    /**
     * Creates a new instance of AbstractSupervisedCostFunction
     * @param costParameters
     * Labeled dataset to use to evaluate the cost against
     */
    public AbstractSupervisedCostFunction(
        Collection<? extends InputOutputPair<? extends InputType, TargetType>> costParameters )
    {
        this.setCostParameters( costParameters );
    }

    /**
     * Creates a copy of this cost function.
     * <p>
     * The clone receives a new {@code ArrayList} of the cost parameters,
     * produced by {@code ObjectUtil.cloneSmartElementsAsArrayList}. See
     * that method for which elements are cloned and which are shared.
     *
     * @return A copy of this cost function.
     */
    @Override
    @SuppressWarnings("unchecked")
    public AbstractSupervisedCostFunction<InputType, TargetType> clone()
    {
        // Cast is safe: Object.clone() returns an instance of this class.
        AbstractSupervisedCostFunction<InputType, TargetType> clone =
            (AbstractSupervisedCostFunction<InputType, TargetType>) super.clone();
        clone.setCostParameters(
            ObjectUtil.cloneSmartElementsAsArrayList(this.getCostParameters()) );
        return clone;
    }

    /**
     * Reduces a collection of target/estimate pairs to a single cost value.
     *
     * @param data
     * Target/estimate pairs to compute the cost of. When called from
     * {@link #evaluate(Evaluator)} these are weighted pairs.
     * @return The cost of the given pairs.
     */
    @Override
    public abstract Double evaluatePerformance(
        Collection<? extends TargetEstimatePair<? extends TargetType, ? extends TargetType>> data );

    /**
     * Computes the cost of the given evaluator on the cost parameters.
     * <p>
     * Each example in the dataset is evaluated by {@code evaluator}. Its
     * output becomes the target, the evaluator's result becomes the
     * estimate, and its weight comes from {@code DatasetUtil.getWeight}.
     * The resulting pairs are passed to
     * {@link #evaluatePerformance(Collection)}.
     *
     * @param evaluator
     * Evaluator whose cost is to be computed.
     * @return The cost of the evaluator on the cost parameters.
     * @throws IllegalStateException
     * If the cost parameters have not been set (are null).
     */
    public Double evaluate(
        Evaluator<? super InputType, ? extends TargetType> evaluator )
    {
        // Read the (overridable) getter once so the size hint and the
        // iteration below see the same collection.
        final Collection<? extends InputOutputPair<? extends InputType, TargetType>> data =
            this.getCostParameters();
        if (data == null)
        {
            // The no-arg constructor leaves the dataset unset; fail with a
            // descriptive message instead of an anonymous NPE.
            throw new IllegalStateException(
                "Cost parameters (labeled dataset) must be set before evaluating the cost." );
        }

        ArrayList<WeightedTargetEstimatePair<TargetType, TargetType>> targetEstimatePairs =
            new ArrayList<WeightedTargetEstimatePair<TargetType, TargetType>>( data.size() );
        for (InputOutputPair<? extends InputType, ? extends TargetType> io : data)
        {
            TargetType target = io.getOutput();
            TargetType estimate = evaluator.evaluate(io.getInput());
            targetEstimatePairs.add(DefaultWeightedTargetEstimatePair.create(
                target, estimate, DatasetUtil.getWeight(io)));
        }

        return this.evaluatePerformance( targetEstimatePairs );
    }

    /**
     * Gets the labeled dataset used to evaluate the cost against.
     *
     * @return The cost parameters; may be null if never set.
     */
    public Collection<? extends InputOutputPair<? extends InputType, TargetType>> getCostParameters()
    {
        return this.costParameters;
    }

    /**
     * Sets the labeled dataset used to evaluate the cost against. The
     * collection is stored by reference, not copied.
     *
     * @param costParameters
     * Labeled dataset to use to evaluate the cost against. May be null.
     */
    public void setCostParameters(
        Collection<? extends InputOutputPair<? extends InputType, TargetType>> costParameters )
    {
        this.costParameters = costParameters;
    }

    /**
     * Summarizes a collection of target/estimate pairs. This simply
     * delegates to {@link #evaluatePerformance(Collection)}.
     *
     * @param data
     * Target/estimate pairs to summarize.
     * @return The cost of the given pairs.
     */
    @Override
    public Double summarize(
        Collection<? extends TargetEstimatePair<? extends TargetType, ? extends TargetType>> data )
    {
        return this.evaluatePerformance(data);
    }

}