/*
 * File:                DistributionParameterEstimator.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Jul 8, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.statistics.method;

import gov.sandia.cognition.algorithm.AnytimeAlgorithmWrapper;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizer;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerBFGS;
import gov.sandia.cognition.learning.function.cost.CostFunction;
import gov.sandia.cognition.math.DifferentiableEvaluator;
import gov.sandia.cognition.math.matrix.NumericalDifferentiator;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ClosedFormDistribution;
import gov.sandia.cognition.statistics.method.DistributionParameterEstimator.DistributionWrapper;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * A method of estimating the parameters of a distribution using an arbitrary
 * CostFunction and FunctionMinimizer algorithm.
 * <p>
 * The distribution's parameters are exposed to the minimizer as a Vector
 * (via {@code convertToVector}/{@code convertFromVector}). The minimizer
 * searches that Vector space for the point with the lowest cost.
 *
 * @param <DataType>
 * Type of data generated by the distribution
 * @param <DistributionType>
 * Type of distribution to estimate the parameters of.
 * @author Kevin R. Dixon
 * @since 3.1
 */
public class DistributionParameterEstimator<DataType, DistributionType extends ClosedFormDistribution<? extends DataType>>
    extends AnytimeAlgorithmWrapper<DistributionType, FunctionMinimizer<Vector, Double, ? super DistributionParameterEstimator<DataType, DistributionType>.DistributionWrapper>>
    implements BatchLearner<Collection<? extends DataType>, DistributionType>,
    MeasurablePerformanceAlgorithm
{

    /**
     * Function that maps a Distribution onto a Vector/Scalar function.
     * This instance acts as a prototype: {@link #learn} works on a clone of
     * it, so the wrapped distribution is never modified by learning.
     */
    private DistributionWrapper distributionWrapper;

    /**
     * Distribution that minimizes the cost function, or null if
     * {@link #learn} has not been called yet.
     */
    private DistributionType result;

    /**
     * Creates a new instance of DistributionParameterEstimator that uses a
     * FunctionMinimizerBFGS as the minimization algorithm.
     * @param distribution
     * Distribution to estimate the parameters of. Its current parameters
     * are used as the initial guess.
     * @param costFunction
     * Cost function to use in the minimization procedure
     */
    public DistributionParameterEstimator(
        DistributionType distribution,
        CostFunction<? super DistributionType, ? super Collection<? extends DataType>> costFunction )
    {
        this( distribution, costFunction, new FunctionMinimizerBFGS() );
    }

    /**
     * Creates a new instance of DistributionParameterEstimator
     * @param distribution
     * Distribution to estimate the parameters of. Its current parameters
     * are used as the initial guess.
     * @param costFunction
     * Cost function to use in the minimization procedure
     * @param algorithm
     * Minimization algorithm to use, such as FunctionMinimizerBFGS,
     * FunctionMinimizerDirectionSetPowell, etc.
     */
    public DistributionParameterEstimator(
        DistributionType distribution,
        CostFunction<? super DistributionType, ? super Collection<? extends DataType>> costFunction,
        FunctionMinimizer<Vector, Double, ? super DistributionParameterEstimator<DataType, DistributionType>.DistributionWrapper> algorithm )
    {
        super( algorithm );
        this.distributionWrapper = new DistributionWrapper( distribution, costFunction );
    }

    @Override
    public DistributionParameterEstimator<DataType, DistributionType> clone()
    {
        @SuppressWarnings("unchecked")
        DistributionParameterEstimator<DataType, DistributionType> clone =
            (DistributionParameterEstimator<DataType, DistributionType>) super.clone();
        clone.distributionWrapper = ObjectUtil.cloneSafe( this.distributionWrapper );
        clone.result = ObjectUtil.cloneSafe( this.getResult() );
        return clone;
    }

    /**
     * Estimates the distribution parameters that minimize the cost function
     * on the given data.
     * @param data
     * Data handed to the cost function as its cost parameters.
     * @return
     * A copy of the prototype distribution whose parameters have been set to
     * the minimizer's best point.
     */
    @Override
    public DistributionType learn(
        Collection<? extends DataType> data )
    {
        // Work on a clone so the prototype distribution and cost function
        // given to the constructor are left untouched.
        DistributionWrapper wrapperClone = this.distributionWrapper.clone();
        wrapperClone.costFunction.setCostParameters( data );

        this.getAlgorithm().setInitialGuess(
            wrapperClone.distribution.convertToVector() );
        this.getAlgorithm().learn( wrapperClone );

        // Each evaluate() call overwrites the wrapped distribution's
        // parameters. That includes the perturbed probes made by numerical
        // differentiation in DistributionWrapper.differentiate. After
        // minimization, the distribution therefore reflects the LAST point
        // evaluated, which is not necessarily the best one. Load the
        // minimizer's reported optimum back in.
        if (this.getAlgorithm().getResult() != null)
        {
            wrapperClone.distribution.convertFromVector(
                this.getAlgorithm().getResult().getInput() );
        }

        this.result = wrapperClone.distribution;
        return this.getResult();
    }

    /**
     * Gets the distribution produced by the most recent call to
     * {@link #learn}.
     * @return
     * The learned distribution, or null if learning has not been run.
     */
    @Override
    public DistributionType getResult()
    {
        return this.result;
    }

    /**
     * Gets the cost reported by the underlying minimizer.
     * @return
     * A value named "Cost". It holds the minimizer's result output, or 0.0
     * when the minimizer has no result yet.
     */
    @Override
    public NamedValue<? extends Number> getPerformance()
    {
        double cost = (this.getAlgorithm().getResult() == null)
            ? 0.0
            : this.getAlgorithm().getResult().getOutput();
        return new DefaultNamedValue<Double>( "Cost", cost );
    }

    /**
     * Maps the parameters of a Distribution and a CostFunction into a
     * Vector/Double Evaluator.
     * <p>
     * Evaluating this function mutates the wrapped distribution: its
     * parameters are replaced by the input Vector before the cost is
     * computed.
     */
    protected class DistributionWrapper
        extends AbstractCloneableSerializable
        implements Evaluator<Vector, Double>,
        DifferentiableEvaluator<Vector, Double, Vector>
    {

        /**
         * Distribution to estimate the parameters of
         */
        protected DistributionType distribution;

        /**
         * Cost function to use in the minimization procedure
         */
        protected CostFunction<? super DistributionType, ? super Collection<? extends DataType>> costFunction;

        /**
         * Creates a new instance of DistributionWrapper
         * @param distribution
         * Distribution to estimate the parameters of
         * @param costFunction
         * Cost function to use in the minimization procedure
         */
        public DistributionWrapper(
            DistributionType distribution,
            CostFunction<? super DistributionType, ? super Collection<? extends DataType>> costFunction )
        {
            this.distribution = distribution;
            this.costFunction = costFunction;
        }

        @Override
        public DistributionWrapper clone()
        {
            @SuppressWarnings("unchecked")
            DistributionWrapper clone = (DistributionWrapper) super.clone();
            clone.distribution = ObjectUtil.cloneSafe( this.distribution );
            clone.costFunction = ObjectUtil.cloneSafe( this.costFunction );
            return clone;
        }

        /**
         * Loads the given parameters into the wrapped distribution and
         * returns the cost of the resulting distribution.
         * @param input
         * Parameter Vector for the distribution.
         * @return
         * The cost, or {@code Double.POSITIVE_INFINITY} if loading the
         * parameters or computing the cost throws.
         */
        @Override
        public Double evaluate(
            Vector input )
        {
            try
            {
                this.distribution.convertFromVector( input );
                return this.costFunction.evaluate( this.distribution );
            }
            catch (Exception e)
            {
                // Parameter vectors the distribution or cost function cannot
                // handle are treated as infinitely costly, so the minimizer
                // steers away from them. The distribution is NOT restored
                // here: it may already hold (some of) the rejected parameters.
                return Double.POSITIVE_INFINITY;
            }
        }

        /**
         * Numerically approximates the gradient of the cost with respect to
         * the parameter Vector. This calls {@link #evaluate} repeatedly,
         * which overwrites the wrapped distribution's parameters.
         * @param input
         * Parameter Vector at which to differentiate.
         * @return
         * Numerical gradient of the cost at {@code input}.
         */
        @Override
        public Vector differentiate(
            Vector input )
        {
            return NumericalDifferentiator.VectorJacobian.differentiate( input, this );
        }

    }

}