/* * File: LearnerValidationExperiment.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright September 21, 2007, Sandia Corporation. Under the terms of Contract * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by * or on behalf of the U.S. Government. Export of this program may require a * license from the United States Government. See CopyrightHistory.txt for * complete details. * */ package gov.sandia.cognition.learning.experiment; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.learning.algorithm.BatchLearner; import gov.sandia.cognition.learning.performance.PerformanceEvaluator; import gov.sandia.cognition.learning.data.PartitionedDataset; import gov.sandia.cognition.util.Summarizer; import java.util.ArrayList; import java.util.Collection; /** * The {@code LearnerValidationExperiment} class implements an experiment where * a supervised machine learning algorithm is evaluated by applying it to a set * of folds created from a given set of data. * * @param <InputDataType> * The type of the data to perform the experiment with. * This will be passed to the fold creator to create a number of folds * on which to validate the performance of the learning algorithm. * @param <FoldDataType> * The type of data created by the fold creator that will go into * the learning algorithm. Typically, this is the same as the * InputDataType, but it does not need to be. It just needs to match * the output of the fold creator and the input of the learning * algorithm. * @param <LearnedType> The type of the output produced by the learning * algorithm whose performance will be evaluated on each fold of data. * @param <StatisticType> The type of the statistic generated by the * performance evaluator on the learned object for each fold. It is * created by passing the learned object plus the test data for the * fold into the performance evaluator. * @param <SummaryType> The type produced by the summarizer at the end of * the experiment from a collection of the given statistics (one for * each fold). This represents the performance result for the learning * algorithm for the whole experiment. * @author Justin Basilico * @since 2.0 */ @PublicationReference( author="Wikipedia", title="Decriptive statistics", type=PublicationType.WebPage, year=2008, url="http://en.wikipedia.org/wiki/Descriptive_statistics" ) public class LearnerValidationExperiment <InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType> extends AbstractValidationFoldExperiment<InputDataType, FoldDataType> implements PerformanceEvaluator<BatchLearner <? super Collection<? extends FoldDataType>, ? extends LearnedType>, Collection<? extends InputDataType>, SummaryType> { /** The evaluator to use to compute the performance of the learned object on * each fold. */ protected PerformanceEvaluator <? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator; /** The summarizer for summarizing the result of the performance evaluator * from all the folds. */ protected Summarizer<? super StatisticType, ? extends SummaryType> summarizer; /** The learner that the experiment is run on. */ private BatchLearner <? super Collection<? extends FoldDataType>, ? extends LearnedType> learner; /** The performance evaluations made during the experiment. */ protected ArrayList<StatisticType> statistics; /** The summary of the performance evaluations made at the end of the * experiment. */ protected SummaryType summary; /** * Creates a new instance of SupervisedLearnerExperiment. */ public LearnerValidationExperiment() { this(null, null, null); } /** * Creates a new instance of SupervisedLearnerExperiment. * * @param foldCreator The object to use for creating the folds. * @param performanceEvaluator The evaluator to use to compute the * performance of the learned object on each fold. * @param summarizer The summarizer for summarizing the result of the * performance evaluator from all the folds. */ public LearnerValidationExperiment( final ValidationFoldCreator<InputDataType, FoldDataType> foldCreator, final PerformanceEvaluator <? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator, final Summarizer<? super StatisticType, ? extends SummaryType> summarizer) { super(foldCreator); this.setPerformanceEvaluator(performanceEvaluator); this.setSummarizer(summarizer); // The initial number of trials is unknown. this.setStatistics(null); this.setSummary(null); } /** * @deprecated Use evaluatePerformance instead. * * Performs the experiment. * * @param data The data to use. * @param learner The learner to perform the experiment on. * @return The summary of the experiment. */ @Deprecated public SummaryType evaluate( final BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType> learner, final Collection<? extends InputDataType> data) { return this.evaluatePerformance(learner, data); } public SummaryType evaluatePerformance( final BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType> learner, final Collection<? extends InputDataType> data) { // The first step in the experiment is to create the folds. final Collection<PartitionedDataset<FoldDataType>> folds = this.getFoldCreator().createFolds(data); this.setLearner(learner); // Initialize the collection where we will store the statistics // generated from the data. this.setStatistics(new ArrayList<StatisticType>(folds.size())); this.setSummary(null); this.runExperiment(folds); // Summarize the statistics. this.setSummary(this.getSummarizer().summarize(this.getStatistics())); return this.getSummary(); } protected void runTrial( final PartitionedDataset<FoldDataType> fold) { // Perform the learning algorithm on this fold. final LearnedType learned = getLearner().learn(fold.getTrainingSet()); // Compute the statistic of the learned object on the testing set. final Collection<FoldDataType> testingSet = fold.getTestingSet(); final StatisticType statistic = this.getPerformanceEvaluator().evaluatePerformance( learned, testingSet); statistics.add(statistic); } /** * Gets the performance evaluator to apply to each fold. * * @return The performance evaluator to apply to each fold. */ public PerformanceEvaluator <? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> getPerformanceEvaluator() { return this.performanceEvaluator; } /** * Sets the performance evaluator to apply to each fold. * * @param performanceEvaluator * The performance evaluator to apply to each fold. */ public void setPerformanceEvaluator( final PerformanceEvaluator <? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator) { this.performanceEvaluator = performanceEvaluator; } /** * Gets the summarizer of the performance evaluations. * * @return The summarizer of the performance evaluations. */ public Summarizer<? super StatisticType, ? extends SummaryType> getSummarizer() { return this.summarizer; } /** * Sets the summarizer of the performance evaluations. * * @param summarizer The summarizer of the performance evaluations. */ public void setSummarizer( final Summarizer<? super StatisticType, ? extends SummaryType> summarizer) { this.summarizer = summarizer; } /** * Gets the learner the experiment is being run on. * * @return The learner. */ public BatchLearner <? super Collection<? extends FoldDataType>, ? extends LearnedType> getLearner() { return this.learner; } /** * Sets the learner the experiment is being run on. * * @param learner The learner. */ protected void setLearner( final BatchLearner <? super Collection<? extends FoldDataType>, ? extends LearnedType> learner) { this.learner = learner; } /** * Gets the performance evaluations for the trials of the experiment. * * @return The performance evaluations for the trials of the experiment. */ public ArrayList<StatisticType> getStatistics() { return this.statistics; } /** * Sets the performance evaluations for the trials of the experiment. * * @param statistics * The performance evaluations for the trials of the experiment. */ protected void setStatistics( final ArrayList<StatisticType> statistics) { this.statistics = statistics; } /** * Gets the summary of the experiment. * * @return The summary of the experiment. */ public SummaryType getSummary() { return this.summary; } /** * Sets the summary of the experiment. * * @param summary The summary of the experiment. */ protected void setSummary( final SummaryType summary) { this.summary = summary; } }