/* * File: IncrementalLearnerValidationExperiment.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright June 10, 2010, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.experiment; import gov.sandia.cognition.learning.algorithm.IncrementalLearner; import gov.sandia.cognition.learning.performance.PerformanceEvaluator; import gov.sandia.cognition.util.Summarizer; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; /** * Implements an experiment where an incremental supervised machine learning * algorithm is evaluated by applying it to a set of data by successively * testing on each item and then training on it. * * @param <DataType> * The type of the data to perform the experiment with. * It will be used as input into the learning algorithm. * @param <LearnedType> The type of the output produced by the learning * algorithm whose performance will be evaluated on each data item. * @param <StatisticType> The type of the statistic generated by the * performance evaluator on the learned object for each data item. It * is created by passing the learned object plus the data item * into the performance evaluator. * @param <SummaryType> The type produced by the summarizer at the end of * the experiment from a collection of the given statistics (one for * each item). This represents the performance result for the learning * algorithm for the whole experiment. * @author Justin Basilico * @since 3.0 */ public class OnlineLearnerValidationExperiment<DataType, LearnedType, StatisticType, SummaryType> extends AbstractLearningExperiment implements PerformanceEvaluator<IncrementalLearner<? super DataType, LearnedType>, Collection<? extends DataType>, SummaryType>, Serializable // TODO: This class is largely copied from LearnerValidationExperiment. // They should probably be merged into abstract classes. // --jdbasil (2010-06-10) { /** The evaluator to use to compute the performance of the learned object on * each fold. */ protected PerformanceEvaluator <? super LearnedType, ? super Collection<? extends DataType>, ? extends StatisticType> performanceEvaluator; /** The summarizer for summarizing the result of the performance evaluator * from all the folds. */ protected Summarizer<? super StatisticType, ? extends SummaryType> summarizer; /** The number of trials in the experiment, which is the number of folds * in the experiment. */ protected int numTrials; /** The performance evaluations made during the experiment. */ protected ArrayList<StatisticType> statistics; /** The summary of the performance evaluations made at the end of the * experiment. */ protected SummaryType summary; /** * Creates a new instance of IncrementalLearnerValidationExperiment. */ public OnlineLearnerValidationExperiment() { this(null, null); } /** * Creates a new instance of IncrementalLearnerValidationExperiment. * * @param performanceEvaluator The evaluator to use to compute the * performance of the learned object on each fold. * @param summarizer The summarizer for summarizing the result of the * performance evaluator from all the folds. */ public OnlineLearnerValidationExperiment( final PerformanceEvaluator <? super LearnedType, ? super Collection<? extends DataType>, ? extends StatisticType> performanceEvaluator, final Summarizer<? super StatisticType, ? extends SummaryType> summarizer) { super(); this.setPerformanceEvaluator(performanceEvaluator); this.setSummarizer(summarizer); // The initial number of trials is unknown. this.setNumTrials(-1); this.setStatistics(null); this.setSummary(null); } /** * Performs the experiment. * * @param data The data to use. * @param learner The learner to perform the experiment on. * @return The summary of the experiment. */ public SummaryType evaluatePerformance( final IncrementalLearner<? super DataType, LearnedType> learner, final Collection<? extends DataType> data) { // Initialize the collection where we will store the statistics // generated from the data. this.setNumTrials(data.size()); this.setStatistics(new ArrayList<StatisticType>(data.size())); this.setSummary(null); // We've started the experiment. this.fireExperimentStarted(); // Initialize learning. final LearnedType learned = learner.createInitialLearnedObject(); // Go through and evaluate each item and then update the model using // it. for (DataType item : data) { // Start a new trial. this.fireTrialStarted(); // Compute the statistic for this item. final StatisticType statistic = this.getPerformanceEvaluator().evaluatePerformance( learned, Collections.singletonList(item)); this.statistics.add(statistic); // Update the learned value. learner.update(learned, item); // The trial has ended. this.fireTrialEnded(); } // Summarize the statistics. this.setSummary(this.getSummarizer().summarize(this.getStatistics())); // The experiment has ended. this.fireExperimentEnded(); // The result is the summary statistic. return this.getSummary(); } /** * Gets the performance evaluator to apply to each fold. * * @return The performance evaluator to apply to each fold. */ public PerformanceEvaluator <? super LearnedType, ? super Collection<? extends DataType>, ? extends StatisticType> getPerformanceEvaluator() { return this.performanceEvaluator; } /** * Sets the performance evaluator to apply to each fold. * * @param performanceEvaluator * The performance evaluator to apply to each fold. */ public void setPerformanceEvaluator( final PerformanceEvaluator <? super LearnedType, ? super Collection<? extends DataType>, ? extends StatisticType> performanceEvaluator) { this.performanceEvaluator = performanceEvaluator; } /** * Gets the summarizer of the performance evaluations. * * @return The summarizer of the performance evaluations. */ public Summarizer<? super StatisticType, ? extends SummaryType> getSummarizer() { return this.summarizer; } /** * Sets the summarizer of the performance evaluations. * * @param summarizer The summarizer of the performance evaluations. */ public void setSummarizer( final Summarizer<? super StatisticType, ? extends SummaryType> summarizer) { this.summarizer = summarizer; } /** * Gets the number of trials. Will be equal to the number of data points * that the experiment is being run over. * * @return * The number of trials. */ public int getNumTrials() { return this.numTrials; } /** * Sets the number of trials. * * @param numTrials * The number of trials. */ protected void setNumTrials( final int numTrials) { this.numTrials = numTrials; } /** * Gets the performance evaluations for the trials of the experiment. * * @return The performance evaluations for the trials of the experiment. */ public ArrayList<StatisticType> getStatistics() { return this.statistics; } /** * Sets the performance evaluations for the trials of the experiment. * * @param statistics * The performance evaluations for the trials of the experiment. */ protected void setStatistics( final ArrayList<StatisticType> statistics) { this.statistics = statistics; } /** * Gets the summary of the experiment. * * @return The summary of the experiment. */ public SummaryType getSummary() { return this.summary; } /** * Sets the summary of the experiment. * * @param summary The summary of the experiment. */ protected void setSummary( final SummaryType summary) { this.summary = summary; } }