/*
* File: MaximumLikelihoodDistributionEstimator.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jul 12, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.method;
import gov.sandia.cognition.algorithm.AbstractParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerDirectionSetPowell;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerNelderMead;
import gov.sandia.cognition.learning.function.cost.ParallelNegativeLogLikelihood;
import gov.sandia.cognition.statistics.ClosedFormComputableDistribution;
import gov.sandia.cognition.statistics.ClosedFormDiscreteUnivariateDistribution;
import gov.sandia.cognition.statistics.DistributionEstimator;
import gov.sandia.cognition.statistics.EstimableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.SmoothUnivariateDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
/**
* Estimates the most-likely distribution, and its corresponding parameters,
* that generated the given data from a pre-determined collection of
* candidate parametric distributions.
* @param <DataType> Type of data generated by the distributions.
* @author Kevin R. Dixon
* @since 3.1
*/
public class MaximumLikelihoodDistributionEstimator<DataType>
extends AbstractParallelAlgorithm
implements BatchLearner<Collection<? extends DataType>,ClosedFormComputableDistribution<DataType>>
{
/**
* Candidate distributions whose parameters are fit to the data by
* {@code learn}. This is null after the no-argument constructor until
* {@code setDistributions} is called with a non-null collection.
*/
private Collection<? extends ClosedFormComputableDistribution<DataType>> distributions;
/**
 * Default constructor. Delegates to the single-argument constructor with a
 * null candidate collection, so candidates must be supplied through
 * {@code setDistributions} before {@code learn} is called.
 */
public MaximumLikelihoodDistributionEstimator()
{
    this( null );
}
/**
 * Creates an estimator that chooses among the given candidates.
 * @param distributions
 * Candidate distributions whose parameters will be estimated; stored
 * through {@code setDistributions} without copying.
 */
public MaximumLikelihoodDistributionEstimator(
    Collection<? extends ClosedFormComputableDistribution<DataType>> distributions )
{
    setDistributions( distributions );
}
/**
 * Copies this estimator. The copy receives its own candidate list, built
 * by {@code ObjectUtil.cloneSmartElementsAsArrayList} from the current
 * candidates, rather than sharing this instance's collection.
 * @return
 * The cloned estimator.
 */
@Override
public MaximumLikelihoodDistributionEstimator<DataType> clone()
{
    @SuppressWarnings("unchecked")
    final MaximumLikelihoodDistributionEstimator<DataType> copy =
        (MaximumLikelihoodDistributionEstimator<DataType>) super.clone();
    copy.setDistributions(
        ObjectUtil.cloneSmartElementsAsArrayList( this.getDistributions() ) );
    return copy;
}
/**
 * Gets the candidate distributions.
 * @return
 * The collection of candidate distributions currently held by this
 * estimator (the stored reference itself, not a copy); may be null.
 */
public Collection<? extends ClosedFormComputableDistribution<DataType>> getDistributions()
{
    return distributions;
}
/**
 * Sets the candidate distributions.
 * @param distributions
 * The collection of candidate distributions to choose among; the
 * reference is stored as given, without copying.
 */
public void setDistributions(
    Collection<? extends ClosedFormComputableDistribution<DataType>> distributions)
{
    this.distributions = distributions;
}
/**
 * Fits every candidate distribution to the data and returns the fitted
 * candidate with the smallest cost. One {@code DistributionEstimationTask}
 * is created per candidate (each working on a clone of the candidate) and
 * the tasks are run through {@code ParallelUtil.executeInParallel} on this
 * algorithm's thread pool.
 *
 * @param data
 * Data to fit the candidate distributions to.
 * @return
 * The fitted distribution with the lowest cost, or null when no task
 * reported a cost strictly below positive infinity (for example, when
 * the candidate collection is empty or every task failed).
 * @throws NullPointerException
 * If no candidate collection has been set.
 * @throws RuntimeException
 * Wrapping any exception raised while executing the tasks.
 */
@SuppressWarnings("unchecked")
public ClosedFormComputableDistribution<DataType> learn(
    Collection<? extends DataType> data)
{
    // Read the candidates once through the getter so sizing and iteration
    // always see the same collection.
    final Collection<? extends ClosedFormComputableDistribution<DataType>> candidates =
        this.getDistributions();
    if( candidates == null )
    {
        // Same exception type as the former implicit failure, but with a
        // message that tells the caller what is missing.
        throw new NullPointerException(
            "No candidate distributions are set; call setDistributions() before learn()." );
    }

    ArrayList<DistributionEstimationTask<DataType>> tasks =
        new ArrayList<DistributionEstimationTask<DataType>>( candidates.size() );
    for( ClosedFormComputableDistribution<DataType> distribution : candidates )
    {
        // Each task gets its own clone so the caller's candidates are
        // never modified by the estimation.
        tasks.add( new DistributionEstimationTask<DataType>(
            (ClosedFormComputableDistribution<DataType>) distribution.clone(), data ) );
    }

    ArrayList<Pair<Double,ClosedFormComputableDistribution<DataType>>> results;
    try
    {
        results = ParallelUtil.executeInParallel( tasks, this.getThreadPool() );
    }
    catch (Exception e)
    {
        throw new RuntimeException( e );
    }

    // Keep the first result with the strictly smallest cost. NaN and
    // positive-infinity costs never satisfy the comparison, so failed
    // candidates are skipped.
    double minCost = Double.POSITIVE_INFINITY;
    ClosedFormComputableDistribution<DataType> minDistribution = null;
    for( Pair<Double,ClosedFormComputableDistribution<DataType>> result : results )
    {
        double cost = result.getFirst();
        if( minCost > cost )
        {
            minCost = cost;
            minDistribution = result.getSecond();
        }
    }
    return minDistribution;
}
/**
* Callable task that estimates the parameters of a single distribution
* from a data set and reports the resulting cost together with the
* fitted distribution.
* @param <DataType>
* Type of data emitted by the distribution
*/
public static class DistributionEstimationTask<DataType>
extends AbstractCloneableSerializable
implements Callable<Pair<Double,ClosedFormComputableDistribution<DataType>>>
{
/**
* Distribution to estimate; used as the starting parameterization.
*/
ClosedFormComputableDistribution<DataType> distribution;
/**
* Data to use in the estimation
*/
Collection<? extends DataType> data;
/**
* Creates a new instance of DistributionEstimationTask
* @param distribution
* Distribution to estimate
* @param data
* Data to use in the estimation
*/
public DistributionEstimationTask(
ClosedFormComputableDistribution<DataType> distribution,
Collection<? extends DataType> data)
{
this.distribution = distribution;
this.data = data;
}
/**
* Fits the distribution to the data and evaluates the fit with a
* ParallelNegativeLogLikelihood cost function.
* <p>
* If the cost of the initial parameterization is infinite or NaN, a
* chain of fallbacks looks for a usable starting point: the
* distribution's own estimator (when it is an EstimableDistribution)
* on the full data, the same estimator on a leading sub-list of the
* data, a short Nelder-Mead run on the full data, and finally a longer
* Nelder-Mead run on the sub-list. Whatever starting point results is
* then refined with a Powell direction-set minimizer on the full data.
*
* @return
* Pair of (final cost, fitted distribution). If any exception escapes
* the procedure, the pair is (positive infinity, clone of the original
* distribution) instead.
* @throws Exception
* Declared by Callable; exceptions raised inside the outer try block
* are caught and converted to an infinite-cost result.
*/
@SuppressWarnings("unchecked")
public Pair<Double,ClosedFormComputableDistribution<DataType>> call()
throws Exception
{
try
{
ParallelNegativeLogLikelihood<DataType> costFunction =
new ParallelNegativeLogLikelihood<DataType>(this.data);
// See if the initial parameterization is "in the ballpark"
ClosedFormComputableDistribution<DataType> result1 =
ObjectUtil.cloneSafe( this.distribution );
double cost1 = costFunction.evaluate(result1);
// The initial parameters don't work, so guess some more
if( Double.isInfinite(cost1) || Double.isNaN(cost1) )
{
ClosedFormComputableDistribution<DataType> result2 =
ObjectUtil.cloneSafe( this.distribution );
// Stays true until some fallback produces a finite, non-NaN cost.
boolean bruteForce = true;
// NOTE(review): integer division makes Nsub 0 whenever the data has
// fewer than 1000 elements, so subList is empty in that case and the
// sub-sample fallbacks below operate on no data -- confirm intended.
int Nsub = Math.min( 1000, this.data.size()/1000 );
// Sub-sample is simply the first Nsub elements of the data.
Collection<? extends DataType> subList =
CollectionUtil.asArrayList(this.data).subList(0, Nsub);
// We've got a closed-form estimator... use that next
if( this.distribution instanceof EstimableDistribution )
{
DistributionEstimator<DataType,ClosedFormComputableDistribution<DataType>> solver =
((EstimableDistribution) this.distribution).getEstimator();
try
{
result2 = solver.learn( this.data );
double cost2 = costFunction.evaluate(result2);
bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2));
}
catch (Exception e)
{
// The estimator failed on the full data: restart from the
// original parameterization and keep falling back.
bruteForce = true;
result2 = ObjectUtil.cloneSafe(this.distribution);
}
if( bruteForce )
{
// Retry the distribution's estimator on the sub-sample only.
try
{
result2 = solver.learn( subList );
double cost2 = costFunction.evaluate(result2);
bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2));
}
catch (Exception e)
{
// The estimator failed on the sub-sample too. bruteForce is
// still true here (it was true on entry to this branch), so
// only the starting point needs to be reset.
result2 = ObjectUtil.cloneSafe(this.distribution);
}
}
}
// Nothing has worked so far, Use Nelder-Mead, which is
// slow but is less susceptible to numerical imprecision
if( bruteForce )
{
// Deliberately short, loose run (10 iterations, tolerance 1.0).
FunctionMinimizerNelderMead minimizer1 =
new FunctionMinimizerNelderMead();
minimizer1.setMaxIterations(10);
minimizer1.setTolerance(1.0);
DistributionParameterEstimator<DataType,ClosedFormComputableDistribution<DataType>> estimator2 =
new DistributionParameterEstimator<DataType, ClosedFormComputableDistribution<DataType>>(
ObjectUtil.cloneSafe(result2), costFunction, minimizer1 );
result2 = estimator2.learn(this.data);
double cost2 = costFunction.evaluate(result2);
// Still no finite cost: sub-sample the data and re-estimate
// with a much larger iteration budget.
if( Double.isInfinite(cost2) || Double.isNaN(cost2) )
{
minimizer1.setMaxIterations(1000);
// Point the cost function at the sub-sample for this run...
costFunction.setCostParameters(subList);
estimator2 = new DistributionParameterEstimator<DataType, ClosedFormComputableDistribution<DataType>>(
ObjectUtil.cloneSafe(result2), costFunction, minimizer1 );
result2 = estimator2.learn(subList);
// ...and restore the full data before the final refinement.
costFunction.setCostParameters(this.data);
// NOTE(review): this cost3 is computed but never read.
double cost3 = costFunction.evaluate(result2);
}
}
// Whatever the fallbacks produced becomes the starting point below.
result1 = result2;
}
// Final refinement on the full data with Powell's direction-set method.
FunctionMinimizerDirectionSetPowell minimizer3 =
new FunctionMinimizerDirectionSetPowell();
DistributionParameterEstimator<DataType,ClosedFormComputableDistribution<DataType>> estimator3 =
new DistributionParameterEstimator<DataType, ClosedFormComputableDistribution<DataType>>(
ObjectUtil.cloneSafe(result1), costFunction, minimizer3 );
ClosedFormComputableDistribution<DataType> result3 =
estimator3.learn(this.data);
double cost3 = costFunction.evaluate(result3);
return DefaultPair.create( cost3, result3 );
}
catch (Exception e)
{
// Best effort: a candidate that cannot be fit is reported with an
// infinite cost instead of failing the whole task.
return DefaultPair.create( Double.POSITIVE_INFINITY, (ClosedFormComputableDistribution<DataType>) this.distribution.clone() );
}
}
}
/**
 * Fits a continuous (smooth univariate) distribution to the data. The
 * candidates are the instances returned by
 * {@code getDistributionClasses(SmoothUnivariateDistribution.class)}; the
 * choice among them is made by a new
 * {@code MaximumLikelihoodDistributionEstimator}.
 *
 * @param data
 * The data to estimate a distribution for.
 * @return
 * The estimated distribution, as returned by {@code learn}.
 * @throws Exception
 * If there is an error in the estimation.
 */
public static SmoothUnivariateDistribution estimateContinuousDistribution(
    Collection<Double> data )
    throws Exception
{
    final LinkedList<SmoothUnivariateDistribution> candidates =
        getDistributionClasses( SmoothUnivariateDistribution.class );
    final MaximumLikelihoodDistributionEstimator<Double> learner =
        new MaximumLikelihoodDistributionEstimator<Double>( candidates );
    final ClosedFormComputableDistribution<Double> best = learner.learn( data );
    return (SmoothUnivariateDistribution) best;
}
/**
 * Fits a discrete univariate distribution to the data. The candidates are
 * the instances returned by
 * {@code getDistributionClasses(ClosedFormDiscreteUnivariateDistribution.class)};
 * the choice among them is made by a new
 * {@code MaximumLikelihoodDistributionEstimator}.
 *
 * @param data
 * The data to estimate a distribution for.
 * @return
 * The estimated distribution, as returned by {@code learn}.
 * @throws Exception
 * If there is an error in the estimation.
 */
@SuppressWarnings(value={"unchecked", "rawtypes"})
public static ClosedFormDiscreteUnivariateDistribution estimateDiscreteDistribution(
    Collection<? extends Number> data )
    throws Exception
{
    final LinkedList<ClosedFormDiscreteUnivariateDistribution> candidates =
        getDistributionClasses( ClosedFormDiscreteUnivariateDistribution.class );
    // The raw constructor call and cast are intentional: the candidate
    // list uses the raw distribution type.
    final MaximumLikelihoodDistributionEstimator<Number> learner =
        new MaximumLikelihoodDistributionEstimator(
            (Collection<? extends ClosedFormComputableDistribution>) candidates );
    return (ClosedFormDiscreteUnivariateDistribution) learner.learn( data );
}
/**
 * Creates one default instance of every class in the package of
 * {@code UnivariateGaussian} that is assignable to both the given base
 * distribution type and {@code ProbabilityFunction}.
 * <p>
 * Classes that cannot be instantiated through their no-argument
 * constructor (any exception from {@code newInstance}) are silently left
 * out of the result.
 *
 * @param <DistributionType>
 * The type of distribution.
 * @param baseDistribution
 * The class of the base distribution.
 * @return
 * Instances of the matching classes found by {@code getClasses}, in the
 * order in which the classes were found.
 * @throws ClassNotFoundException
 * If a discovered class file cannot be loaded.
 * @throws IOException
 * If the package resources cannot be enumerated.
 * @throws InstantiationException
 * Declared for compatibility; instantiation failures are skipped.
 * @throws IllegalAccessException
 * Declared for compatibility; access failures are skipped.
 */
@SuppressWarnings("unchecked")
protected static <DistributionType extends ClosedFormComputableDistribution<?>> LinkedList<DistributionType> getDistributionClasses(
    Class<? extends DistributionType> baseDistribution )
    throws ClassNotFoundException, IOException, InstantiationException, IllegalAccessException
{
    // The class literal identifies the distribution package without
    // constructing a throwaway instance.
    final Package distributionPackage = UnivariateGaussian.class.getPackage();

    final LinkedList<DistributionType> instances =
        new LinkedList<DistributionType>();
    for( Class<?> candidate : getClasses( distributionPackage.getName() ) )
    {
        if( baseDistribution.isAssignableFrom( candidate )
            && ProbabilityFunction.class.isAssignableFrom( candidate ) )
        {
            try
            {
                instances.add( (DistributionType) candidate.newInstance() );
            }
            catch (Exception ignored)
            {
                // Best effort: a class that cannot be created with its
                // no-argument constructor is simply not a candidate.
            }
        }
    }
    return instances;
}
/**
 * Finds the classes of the given package that are visible to the context
 * class loader as class files in directories. Only files directly inside
 * each package directory are considered, because {@code findClasses} does
 * not descend into subdirectories.
 * <p>
 * Resource URLs are converted to files through their URI form so that
 * escaped characters (for example {@code %20} for a space) are decoded.
 * When a URL cannot be converted that way, its raw file part is used.
 *
 * @param packageName The base package
 * @return The classes
 * @throws ClassNotFoundException If a discovered class file cannot be loaded
 * @throws IOException If the package resources cannot be enumerated
 */
private static LinkedList<Class<?>> getClasses(
    String packageName)
    throws ClassNotFoundException, IOException
{
    ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
    if (classLoader == null)
    {
        // The context class loader may legitimately be null; an assert
        // gives no protection when assertions are disabled.
        classLoader = ClassLoader.getSystemClassLoader();
    }

    String path = packageName.replace('.', '/');
    Enumeration<URL> resources = classLoader.getResources(path);
    List<File> dirs = new ArrayList<File>();
    while (resources.hasMoreElements())
    {
        URL resource = resources.nextElement();
        File directory;
        try
        {
            // Decodes URL escapes, unlike URL.getFile().
            directory = new File(resource.toURI());
        }
        catch (java.net.URISyntaxException e)
        {
            // Malformed URL (e.g. unescaped characters): use the raw path.
            directory = new File(resource.getFile());
        }
        catch (IllegalArgumentException e)
        {
            // Not a plain hierarchical file: URI: use the raw path.
            directory = new File(resource.getFile());
        }
        dirs.add(directory);
    }

    LinkedList<Class<?>> classes = new LinkedList<Class<?>>();
    for (File directory : dirs)
    {
        classes.addAll(findClasses(directory, packageName));
    }
    return classes;
}
/**
 * Loads every class whose {@code .class} file lies directly inside the
 * given directory. Subdirectories are not searched.
 *
 * @param directory The base directory
 * @param packageName The package name for classes found inside the base directory
 * @return The classes; empty when the directory does not exist, is not a
 * directory, or cannot be listed
 * @throws ClassNotFoundException If a class file cannot be loaded under
 * the given package name
 */
private static LinkedList<Class<?>> findClasses(
    File directory,
    String packageName)
    throws ClassNotFoundException
{
    LinkedList<Class<?>> classes = new LinkedList<Class<?>>();
    File[] files = directory.listFiles();
    if (files == null)
    {
        // listFiles() returns null for a missing path, a non-directory or
        // an I/O error; treat that as "no classes here".
        return classes;
    }

    final String suffix = ".class";
    for (File file : files)
    {
        String fileName = file.getName();
        if (fileName.endsWith(suffix))
        {
            String simpleName =
                fileName.substring(0, fileName.length() - suffix.length());
            classes.add(Class.forName(packageName + '.' + simpleName));
        }
    }
    return classes;
}
}