/* * File: MaximumLikelihoodDistributionEstimator.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Jul 12, 2010, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.statistics.method; import gov.sandia.cognition.algorithm.AbstractParallelAlgorithm; import gov.sandia.cognition.algorithm.ParallelUtil; import gov.sandia.cognition.collection.CollectionUtil; import gov.sandia.cognition.learning.algorithm.BatchLearner; import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerDirectionSetPowell; import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizerNelderMead; import gov.sandia.cognition.learning.function.cost.ParallelNegativeLogLikelihood; import gov.sandia.cognition.statistics.ClosedFormComputableDistribution; import gov.sandia.cognition.statistics.ClosedFormDiscreteUnivariateDistribution; import gov.sandia.cognition.statistics.DistributionEstimator; import gov.sandia.cognition.statistics.EstimableDistribution; import gov.sandia.cognition.statistics.ProbabilityFunction; import gov.sandia.cognition.statistics.SmoothUnivariateDistribution; import gov.sandia.cognition.statistics.distribution.UnivariateGaussian; import gov.sandia.cognition.util.AbstractCloneableSerializable; import gov.sandia.cognition.util.DefaultPair; import gov.sandia.cognition.util.ObjectUtil; import gov.sandia.cognition.util.Pair; import java.io.File; import java.io.IOException; import java.net.URL; import java.util.ArrayList; import java.util.Collection; import java.util.Enumeration; import java.util.LinkedList; import java.util.List; import java.util.concurrent.Callable; /** * Estimates the most-likely distribution, and corresponding parameters, of * that generated the given data from a pre-determined collection of * candidate parameteric distributions. * @param <DataType> Type of data generated by the distributions. * @author Kevin R. Dixon * @since 3.1 */ public class MaximumLikelihoodDistributionEstimator<DataType> extends AbstractParallelAlgorithm implements BatchLearner<Collection<? extends DataType>,ClosedFormComputableDistribution<DataType>> { /** * Collection of Distributions to estimate the optimal parameters of */ private Collection<? extends ClosedFormComputableDistribution<DataType>> distributions; /** * Creates a new instance of MaximumLikelihoodDistributionEstimator */ public MaximumLikelihoodDistributionEstimator() { this( null ); } /** * Creates a new instance of MaximumLikelihoodDistributionEstimator * @param distributions * Collection of Distributions to estimate the optimal parameters of */ public MaximumLikelihoodDistributionEstimator( Collection<? extends ClosedFormComputableDistribution<DataType>> distributions ) { this.setDistributions(distributions); } @Override public MaximumLikelihoodDistributionEstimator<DataType> clone() { @SuppressWarnings("unchecked") MaximumLikelihoodDistributionEstimator<DataType> clone = (MaximumLikelihoodDistributionEstimator<DataType>) super.clone(); clone.setDistributions( ObjectUtil.cloneSmartElementsAsArrayList( this.getDistributions() ) ); return clone; } /** * Getter for distributions * @return * Collection of Distributions to estimate the optimal parameters of */ public Collection<? extends ClosedFormComputableDistribution<DataType>> getDistributions() { return this.distributions; } /** * Setter for distributions * @param distributions * Collection of Distributions to estimate the optimal parameters of */ public void setDistributions( Collection<? extends ClosedFormComputableDistribution<DataType>> distributions) { this.distributions = distributions; } @SuppressWarnings("unchecked") public ClosedFormComputableDistribution<DataType> learn( Collection<? extends DataType> data) { ArrayList<DistributionEstimationTask<DataType>> tasks = new ArrayList<DistributionEstimationTask<DataType>>( this.distributions.size() ); for( ClosedFormComputableDistribution<DataType> distribution : this.getDistributions() ) { tasks.add( new DistributionEstimationTask<DataType>( (ClosedFormComputableDistribution<DataType>) distribution.clone(), data ) ); } ArrayList<Pair<Double,ClosedFormComputableDistribution<DataType>>> results; try { results = ParallelUtil.executeInParallel(tasks,this.getThreadPool()); } catch (Exception e) { throw new RuntimeException(e); } double minCost = Double.POSITIVE_INFINITY; ClosedFormComputableDistribution<DataType> minDistribution = null; for( Pair<Double,ClosedFormComputableDistribution<DataType>> result : results ) { double cost = result.getFirst(); if( minCost > cost ) { minCost = cost; minDistribution = result.getSecond(); } } return minDistribution; } /** * Estimates the optimal parameters of a single distribution * @param <DataType> * Type of data emitted by the distribution */ public static class DistributionEstimationTask<DataType> extends AbstractCloneableSerializable implements Callable<Pair<Double,ClosedFormComputableDistribution<DataType>>> { /** * Distribution to estimate */ ClosedFormComputableDistribution<DataType> distribution; /** * Data to use in the estimation */ Collection<? extends DataType> data; /** * Creates a new instance of DistributionEstimationTask * @param distribution * Distribution to estimate * @param data * Data to use in the estimation */ public DistributionEstimationTask( ClosedFormComputableDistribution<DataType> distribution, Collection<? extends DataType> data) { this.distribution = distribution; this.data = data; } @SuppressWarnings("unchecked") public Pair<Double,ClosedFormComputableDistribution<DataType>> call() throws Exception { try { ParallelNegativeLogLikelihood<DataType> costFunction = new ParallelNegativeLogLikelihood<DataType>(this.data); // final int N = this.data.size(); //// final double tolerance = LineBracketInterpolatorBrent.DEFAULT_TOLERANCE / N; // final double tolerance = 1e-100; // LineBracketInterpolatorBrent brent = new LineBracketInterpolatorBrent(); // brent.setTolerance(tolerance); // brent.getGoldenInterpolator().setTolerance(tolerance); // brent.getParabolicInterpolator().setTolerance(tolerance); // LineMinimizerDerivativeFree liner = new LineMinimizerDerivativeFree( brent ); // liner.setTolerance(tolerance); // FunctionMinimizerDirectionSetPowell minimizer = //// new FunctionMinimizerDirectionSetPowell(); // new FunctionMinimizerDirectionSetPowell( liner ); // minimizer.setTolerance(tolerance); // See if the initial parameterization is "in the ballpark" ClosedFormComputableDistribution<DataType> result1 = ObjectUtil.cloneSafe( this.distribution ); double cost1 = costFunction.evaluate(result1); // System.out.println( "Initial Cost: " + cost1 + ", Class: " + result1.getClass().getCanonicalName() + ", Parameters: " + result1.convertToVector() ); // The initial parameters don't work, so guess some more if( Double.isInfinite(cost1) || Double.isNaN(cost1) ) { ClosedFormComputableDistribution<DataType> result2 = ObjectUtil.cloneSafe( this.distribution ); boolean bruteForce = true; int Nsub = Math.min( 1000, this.data.size()/1000 ); // int Nsub = (int) Math.ceil( this.data.size() / 1000 ); Collection<? extends DataType> subList = CollectionUtil.asArrayList(this.data).subList(0, Nsub); // We've got a closed-form estimator... use that next if( this.distribution instanceof EstimableDistribution ) { DistributionEstimator<DataType,ClosedFormComputableDistribution<DataType>> solver = ((EstimableDistribution) this.distribution).getEstimator(); try { result2 = solver.learn( this.data ); double cost2 = costFunction.evaluate(result2); // System.out.println( "Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() ); bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2)); } catch (Exception e) { // System.out.println( "Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e ); bruteForce = true; result2 = ObjectUtil.cloneSafe(this.distribution); } if( bruteForce ) { try { result2 = solver.learn( subList ); double cost2 = costFunction.evaluate(result2); // System.out.println( "Sub-Solver Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() ); bruteForce = (Double.isInfinite(cost2) || Double.isNaN(cost2)); } catch (Exception e) { // System.out.println( "Sub-Solver barfed: " + solver.getClass().getCanonicalName() + ", Exception: " + e ); result2 = ObjectUtil.cloneSafe(this.distribution); } } } // Nothing has worked so far, Use Nelder-Mead, which is // slow but is less susceptible to numerical imprecision if( bruteForce ) { FunctionMinimizerNelderMead minimizer1 = new FunctionMinimizerNelderMead(); minimizer1.setMaxIterations(10); minimizer1.setTolerance(1.0); DistributionParameterEstimator<DataType,ClosedFormComputableDistribution<DataType>> estimator2 = new DistributionParameterEstimator<DataType, ClosedFormComputableDistribution<DataType>>( ObjectUtil.cloneSafe(result2), costFunction, minimizer1 ); result2 = estimator2.learn(this.data); double cost2 = costFunction.evaluate(result2); // System.out.println( "Brute Cost: " + cost2 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() ); // Damn.. nothing has worked so far... subsample the // data and re-estimate. if( Double.isInfinite(cost2) || Double.isNaN(cost2) ) { minimizer1.setMaxIterations(1000); costFunction.setCostParameters(subList); estimator2 = new DistributionParameterEstimator<DataType, ClosedFormComputableDistribution<DataType>>( ObjectUtil.cloneSafe(result2), costFunction, minimizer1 ); result2 = estimator2.learn(subList); costFunction.setCostParameters(this.data); double cost3 = costFunction.evaluate(result2); // System.out.println( "Subsample Cost: " + cost3 + ", Class: " + result2.getClass().getCanonicalName() + ", Parameters: " + result2.convertToVector() ); } } result1 = result2; } FunctionMinimizerDirectionSetPowell minimizer3 = new FunctionMinimizerDirectionSetPowell(); DistributionParameterEstimator<DataType,ClosedFormComputableDistribution<DataType>> estimator3 = new DistributionParameterEstimator<DataType, ClosedFormComputableDistribution<DataType>>( ObjectUtil.cloneSafe(result1), costFunction, minimizer3 ); ClosedFormComputableDistribution<DataType> result3 = estimator3.learn(this.data); double cost3 = costFunction.evaluate(result3); // System.out.println( "Final Cost: " + cost3 + ", Class: " + result3.getClass().getCanonicalName() + ", Parameters: " + result3.convertToVector() ); return DefaultPair.create( cost3, result3 ); } catch (Exception e) { // System.out.println( this.distribution.getClass().getCanonicalName() + " barfed: " + e ); // e.printStackTrace(); return DefaultPair.create( Double.POSITIVE_INFINITY, (ClosedFormComputableDistribution<DataType>) this.distribution.clone() ); } } } /** * Estimates a continuous distribution. * * @param data * The data to estimate a distribution for. * @return * The estimated distribution. * @throws Exception * If there is an error in the estimation. */ public static SmoothUnivariateDistribution estimateContinuousDistribution( Collection<Double> data ) throws Exception { LinkedList<SmoothUnivariateDistribution> distributions = getDistributionClasses( SmoothUnivariateDistribution.class ); MaximumLikelihoodDistributionEstimator<Double> estimator = new MaximumLikelihoodDistributionEstimator<Double>( distributions ); return (SmoothUnivariateDistribution) estimator.learn(data); } /** * Estimates a discrete distribution. * * @param data * The data to estimate a distribution for. * @return * The estimated distribution. * @throws Exception * If there is an error in the estimation. */ @SuppressWarnings(value={"unchecked", "rawtypes"}) public static ClosedFormDiscreteUnivariateDistribution estimateDiscreteDistribution( Collection<? extends Number> data ) throws Exception { LinkedList<ClosedFormDiscreteUnivariateDistribution> distributions = getDistributionClasses( ClosedFormDiscreteUnivariateDistribution.class ); MaximumLikelihoodDistributionEstimator<Number> estimator = new MaximumLikelihoodDistributionEstimator( (Collection<? extends ClosedFormComputableDistribution>) distributions); return (ClosedFormDiscreteUnivariateDistribution) estimator.learn(data); } /** * Gets the distribution classes for the given base distribution. * * @param <DistributionType> * The type of distribution. * @param baseDistribution * The class of the base distribution. * @return * The list of implementations of that distribution in the statistics * distribution package. * @throws ClassNotFoundException * @throws IOException * @throws InstantiationException * @throws IllegalAccessException */ @SuppressWarnings("unchecked") protected static <DistributionType extends ClosedFormComputableDistribution<?>> LinkedList<DistributionType> getDistributionClasses( Class<? extends DistributionType> baseDistribution ) throws ClassNotFoundException, IOException, InstantiationException, IllegalAccessException { UnivariateGaussian g = new UnivariateGaussian(); Package p = g.getClass().getPackage(); LinkedList<Class<?>> cs = getClasses( p.getName() ); LinkedList<DistributionType> instances = new LinkedList<DistributionType>(); for( Class<?> c : cs ) { if( baseDistribution.isAssignableFrom( c ) ) { if( ProbabilityFunction.class.isAssignableFrom(c) ) { try { instances.add( (DistributionType) c.newInstance()); } catch (Exception e) { // System.out.println( "Couldn't instantiate: " + c.getCanonicalName() ); } } } } return instances; } /** * Scans all classes accessible from the context class loader which belong to the given package and subpackages. * * @param packageName The base package * @return The classes * @throws ClassNotFoundException * @throws IOException */ private static LinkedList<Class<?>> getClasses( String packageName) throws ClassNotFoundException, IOException { ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); assert classLoader != null; String path = packageName.replace('.', '/'); Enumeration<URL> resources = classLoader.getResources(path); List<File> dirs = new ArrayList<File>(); while (resources.hasMoreElements()) { URL resource = resources.nextElement(); dirs.add(new File(resource.getFile())); } LinkedList<Class<?>> classes = new LinkedList<Class<?>>(); for (File directory : dirs) { classes.addAll(findClasses(directory, packageName)); } return classes; } /** * Recursive method used to find all classes in a given directory and subdirs. * * @param directory The base directory * @param packageName The package name for classes found inside the base directory * @return The classes * @throws ClassNotFoundException */ private static LinkedList<Class<?>> findClasses( File directory, String packageName) throws ClassNotFoundException { LinkedList<Class<?>> classes = new LinkedList<Class<?>>(); File[] files = directory.listFiles(); for (File file : files) { if (file.getName().endsWith(".class")) { classes.add(Class.forName(packageName + '.' + file.getName().substring(0, file.getName().length() - 6))); } } return classes; } }