/* * File: LearnerComparisonExperiment.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Framework Lite * * Copyright October 1, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * * */ package gov.sandia.cognition.learning.experiment; import gov.sandia.cognition.learning.algorithm.BatchLearner; import gov.sandia.cognition.learning.performance.PerformanceEvaluator; import gov.sandia.cognition.util.Summarizer; import gov.sandia.cognition.learning.data.PartitionedDataset; import gov.sandia.cognition.statistics.method.ConfidenceStatistic; import gov.sandia.cognition.statistics.method.NullHypothesisEvaluator; import gov.sandia.cognition.util.DefaultPair; import gov.sandia.cognition.util.Pair; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; /** * The {@code LearnerComparisonExperiment} compares the performance of two * machine learning algorithms to determine (using a statistical test) if the * two algorithms have significantly different performance. * *@param <InputDataType> * The type of the data to perform the experiment with. * This will be passed to the fold creator to create a number of folds * on which to validate the performance of the learning algorithms. * @param <FoldDataType> * The type of data created by the fold creator that will go into * the learning algorithms. Typically, this is the same as the * InputDataType, but it does not need to be. It just needs to match * the output of the fold creator and the input of the learning * algorithms. * @param <LearnedType> The type of the output produced by the learning * algorithms whose performance will be evaluated on each fold of data. * @param <StatisticType> The type of the statistic generated by the * performance evaluator on each learned objects for each fold. It is * created by passing the learned object plus the test data for the * fold into the performance evaluator. * @param <SummaryType> The type produced by the summarizer at the end of * the experiment from a collection of the given statistics (one for * each fold). This represents the performance result for the * comparison of the performance of the learning algorithms for the * whole experiment. * @author Justin Basilico * @since 2.0 */ public class LearnerComparisonExperiment<InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType> extends AbstractValidationFoldExperiment<InputDataType, FoldDataType> implements Serializable { /** The evaluator to use to compute the performance of the learned object on * each fold. */ protected PerformanceEvaluator<? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator; /** The summarizer for summarizing the result of the performance evaluator * from all the folds. */ protected Summarizer<? super StatisticType, ? extends SummaryType> summarizer; /** The statistical test to use to determine if the two learners are * significantly different. */ private NullHypothesisEvaluator<Collection<? extends StatisticType>> statisticalTest; /** The learners that the experiment is being performed on. */ protected Pair<BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>, BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>> learners; /** The performance evaluations made during the experiment. */ protected DefaultPair<ArrayList<StatisticType>, ArrayList<StatisticType>> statistics; /** The confidence statistic generated from the underlying performance * statistics. */ protected ConfidenceStatistic confidence; /** The summaries of performance. */ protected DefaultPair<SummaryType, SummaryType> summaries; /** * Creates a new instance of LearnerComparisonExperiment. */ public LearnerComparisonExperiment() { this(null, null, null, null); } /** * Creates a new instance of LearnerComparisonExperiment. * * @param foldCreator The object to use for creating the folds. * @param performanceEvaluator The evaluator to use to compute the * performance of the learned object on each fold. * @param statisticalTest The statistical test to apply to the performance * results of the two learners to determine if they are * statistically different. * @param summarizer The summarizer for summarizing the result of the * performance evaluator from all the folds. */ public LearnerComparisonExperiment( final ValidationFoldCreator<InputDataType, FoldDataType> foldCreator, final PerformanceEvaluator<? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator, final NullHypothesisEvaluator<Collection<? extends StatisticType>> statisticalTest, final Summarizer<? super StatisticType, ? extends SummaryType> summarizer) { super(foldCreator); this.setPerformanceEvaluator(performanceEvaluator); this.setStatisticalTest(statisticalTest); this.setSummarizer(summarizer); // The initial number of trials is unknown. this.setStatistics(null); this.setConfidence(null); this.setSummaries(null); } /** * Evaluates the two batch learners using the given data on the same set of * validation folds and returns the resulting information including the * confidence statistic that the two are different along with the summary * of their performance. * * @param learners The two learners. * @param data The data to use. * @return The experimental results. */ public Result<SummaryType> evaluate( final Pair<BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>, BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>> learners, final Collection<? extends InputDataType> data) { // The first step in the experiment is to create the folds. final Collection<PartitionedDataset<FoldDataType>> folds = this.getFoldCreator().createFolds(data); this.setLearners(learners); // Initialize the collection where we will store the statistics // generated from the data. this.setStatistics( new DefaultPair<ArrayList<StatisticType>, ArrayList<StatisticType>>( new ArrayList<StatisticType>(folds.size()), new ArrayList<StatisticType>(folds.size()))); // Run the experiment. this.runExperiment(folds); // The confidence. this.setConfidence(this.getStatisticalTest().evaluateNullHypothesis( this.getStatistics().getFirst(), this.getStatistics().getSecond())); // Summarize the statistics. final SummaryType summary1 = this.getSummarizer().summarize( this.getStatistics().getFirst()); final SummaryType summary2 = this.getSummarizer().summarize( this.getStatistics().getSecond()); this.setSummaries( new DefaultPair<SummaryType, SummaryType>(summary1, summary2)); return new Result<SummaryType>( this.getConfidence(), this.getSummaries()); } /** * {@inheritDoc} * * @param fold {@inheritDoc} */ protected void runTrial( final PartitionedDataset<FoldDataType> fold) { // Perform the learning algorithm on this fold for the first learner. final LearnedType learned1 = this.getLearners().getFirst().learn(fold.getTrainingSet()); // Compute the statistic and add it to the collection for the first // learner. final StatisticType statistic1 = this.getPerformanceEvaluator().evaluatePerformance( learned1, fold.getTestingSet()); this.getStatistics().getFirst().add(statistic1); // Perform the learning algorithm on this fold for the second learner. final LearnedType learned2 = this.getLearners().getSecond().learn(fold.getTrainingSet()); // Compute the statistic and add it to the collection for the second // learner. final StatisticType statistic2 = this.getPerformanceEvaluator().evaluatePerformance( learned2, fold.getTestingSet()); this.getStatistics().getSecond().add(statistic2); } /** * Evaluates the two batch learners using the given data on the same set of * validation folds and returns the resulting information including the * confidence statistic that the two are different along with the summary * of their performance. * * @param learner1 The first learner. * @param learner2 The second learner. * @param data The data to use. * @return The experimental results. */ public Result<SummaryType> evaluate( final BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType> learner1, final BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType> learner2, final Collection<? extends InputDataType> data) { return this.evaluate( new DefaultPair<BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>, BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>>( learner1, learner2), data); } /** * Gets the performance evaluator to apply to each fold. * * @return The performance evaluator to apply to each fold. */ public PerformanceEvaluator<? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> getPerformanceEvaluator() { return this.performanceEvaluator; } /** * Sets the performance evaluator to apply to each fold. * * @param performanceEvaluator * The performance evaluator to apply to each fold. */ public void setPerformanceEvaluator( final PerformanceEvaluator<? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator) { this.performanceEvaluator = performanceEvaluator; } /** * Gets the summarizer of the performance evaluations. * * @return The summarizer of the performance evaluations. */ public Summarizer<? super StatisticType, ? extends SummaryType> getSummarizer() { return this.summarizer; } /** * Sets the summarizer of the performance evaluations. * * @param summarizer The summarizer of the performance evaluations. */ public void setSummarizer( final Summarizer<? super StatisticType, ? extends SummaryType> summarizer) { this.summarizer = summarizer; } /** * Gets the statistical test to use to determine if the two learners are * significantly different. * * @return The statistical test. */ public NullHypothesisEvaluator<Collection<? extends StatisticType>> getStatisticalTest() { return this.statisticalTest; } /** * Sets the statistical test to use to determine if the two learners are * significantly different. * * @param statisticalTest The statistical test. */ public void setStatisticalTest( final NullHypothesisEvaluator<Collection<? extends StatisticType>> statisticalTest) { this.statisticalTest = statisticalTest; } /** * Gets the learners the experiment is being run on. * * @return The learners. */ public Pair<BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>, BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>> getLearners() { return this.learners; } /** * Sets the learners the experiment is being run on. * * @param learners The learners. */ protected void setLearners( final Pair<BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>, BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType>> learners) { this.learners = learners; } /** * Gets the performance evaluations for the trials of the experiment. * * @return The performance evaluations for the trials of the experiment. */ public DefaultPair<ArrayList<StatisticType>, ArrayList<StatisticType>> getStatistics() { return this.statistics; } /** * Sets the performance evaluations for the trials of the experiment. * * @param statistics * The performance evaluations for the trials of the experiment. */ protected void setStatistics( final DefaultPair<ArrayList<StatisticType>, ArrayList<StatisticType>> statistics) { this.statistics = statistics; } /** * Gets the confidence statistic that the two learners are different. * * @return The confidence statistic of the experiment. */ public ConfidenceStatistic getConfidence() { return this.confidence; } /** * Sets the confidence statistic that the two learners are different. * * @param confidence The confidence statistic of the experiment. */ protected void setConfidence( final ConfidenceStatistic confidence) { this.confidence = confidence; } /** * Gets the summaries of the experiment. * * @return The summaries of the experiment. */ public DefaultPair<SummaryType, SummaryType> getSummaries() { return this.summaries; } /** * Sets the summaries of the experiment. * * @param summaries The summaries of the experiment. */ protected void setSummaries( final DefaultPair<SummaryType, SummaryType> summaries) { this.summaries = summaries; } /** * Encapsulates the results of the comparison experiment. * * @param <SummaryType> The type produced by the summarizer at the end of * the experiment from a collection of the given statistics (one * for each fold). This represents the performance result for the * comparison of the performance of the learning algorithms for the * whole experiment. */ public static class Result<SummaryType> { /** The confidence statistic for the learners. */ private ConfidenceStatistic confidence; /** The summary of performance for the learners. */ private DefaultPair<SummaryType, SummaryType> summaries; /** * Creates a new instance of Result. * * @param confidence The confidence statistic for the learners. * @param summaries The summary of performance for the learners. */ public Result( final ConfidenceStatistic confidence, final DefaultPair<SummaryType, SummaryType> summaries) { super(); this.setConfidence(confidence); this.setSummaries(summaries); } /** * Gets the confidence statistic for the learners. * * @return The confidence statistic for the learners. */ public ConfidenceStatistic getConfidence() { return this.confidence; } /** * Sets the confidence statistic for the learners. * * @param confidence The confidence statistic for the learners. */ public void setConfidence( final ConfidenceStatistic confidence) { this.confidence = confidence; } /** * Gets the summary of performance for the learners. * * @return The summary of performance for the learners. */ public DefaultPair<SummaryType, SummaryType> getSummaries() { return this.summaries; } /** * Sets the summary of performance for the learners. * * @param summaries The summary of performance for the learners. */ public void setSummaries( final DefaultPair<SummaryType, SummaryType> summaries) { this.summaries = summaries; } } }