/* * File: ParallelLearnerValidationExperiment.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright October 04, 2008, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.experiment; import gov.sandia.cognition.algorithm.ParallelAlgorithm; import gov.sandia.cognition.algorithm.ParallelUtil; import gov.sandia.cognition.learning.algorithm.BatchLearner; import gov.sandia.cognition.learning.data.PartitionedDataset; import gov.sandia.cognition.learning.performance.PerformanceEvaluator; import gov.sandia.cognition.util.ObjectUtil; import gov.sandia.cognition.util.Summarizer; import java.util.Collection; import java.util.LinkedList; import java.util.concurrent.Callable; import java.util.concurrent.ThreadPoolExecutor; import java.util.logging.Level; import java.util.logging.Logger; /** * Parallel version of the LearnerValidationExperiment class that executes * the validations experiments across available cores and hyperthreads. * * @param <InputDataType> * The type of the data to perform the experiment with. * This will be passed to the fold creator to create a number of folds * on which to validate the performance of the learning algorithm. * @param <FoldDataType> * The type of data created by the fold creator that will go into * the learning algorithm. Typically, this is the same as the * InputDataType, but it does not need to be. It just needs to match * the output of the fold creator and the input of the learning * algorithm. * @param <LearnedType> The type of the output produced by the learning * algorithm whose performance will be evaluated on each fold of data. * @param <StatisticType> The type of the statistic generated by the * performance evaluator on the learned object for each fold. It is * created by passing the learned object plus the test data for the * fold into the performance evaluator. * @param <SummaryType> The type produced by the summarizer at the end of * the experiment from a collection of the given statistics (one for * each fold). This represents the performance result for the learning * algorithm for the whole experiment. * @author Justin Basilico * @since 3.0 */ public class ParallelLearnerValidationExperiment<InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType> extends LearnerValidationExperiment<InputDataType, FoldDataType, LearnedType, StatisticType, SummaryType> implements ParallelAlgorithm { /** * Thread pool used to split the computation across multiple cores */ private transient ThreadPoolExecutor threadPool; /** * Default constructor */ public ParallelLearnerValidationExperiment() { this(null, null, null); } /** * Creates a new instance of ParallelLearnerValidationExperiment. * * @param foldCreator The object to use for creating the folds. * @param performanceEvaluator The evaluator to use to compute the * performance of the learned object on each fold. * @param summarizer The summarizer for summarizing the result of the * performance evaluator from all the folds. */ public ParallelLearnerValidationExperiment( final ValidationFoldCreator<InputDataType, FoldDataType> foldCreator, final PerformanceEvaluator <? super LearnedType, ? super Collection<? extends FoldDataType>, ? extends StatisticType> performanceEvaluator, final Summarizer<? super StatisticType, ? extends SummaryType> summarizer) { super(foldCreator, performanceEvaluator, summarizer); } @Override protected void runExperiment( final Collection<PartitionedDataset<FoldDataType>> folds) { // The number of trials is the number of folds. this.setNumTrials(folds.size()); this.fireExperimentStarted(); LinkedList<Callable<StatisticType>> trials = new LinkedList<Callable<StatisticType>>(); // Go through the folds and run the trial for each fold. for (PartitionedDataset<FoldDataType> fold : folds) { final TrialTask trial = new TrialTask(fold); trials.add(trial); } Collection<StatisticType> results = null; try { results = ParallelUtil.executeInParallel( trials, this.getThreadPool() ); } catch (Exception ex) { Logger.getLogger(ParallelLearnerValidationExperiment.class.getName()).log(Level.SEVERE, null, ex); } this.getStatistics().addAll( results ); this.fireExperimentEnded(); } public ThreadPoolExecutor getThreadPool() { if (this.threadPool == null) { this.threadPool = ParallelUtil.createThreadPool(); } return this.threadPool; } public void setThreadPool( final ThreadPoolExecutor threadPool) { this.threadPool = threadPool; } public int getNumThreads() { return ParallelUtil.getNumThreads( this ); } /** * Callable task for a single evaluation trial */ private class TrialTask extends Object implements Callable<StatisticType> { /** * Dataset partition */ private PartitionedDataset<FoldDataType> fold; /** * Creates a new instance of TrialTask * @param fold * Dataset partition */ public TrialTask( final PartitionedDataset<FoldDataType> fold) { super(); this.fold = fold; } @Override public StatisticType call() throws Exception { fireTrialStarted(); // Perform the learning algorithm on this fold. final BatchLearner<? super Collection<? extends FoldDataType>, ? extends LearnedType> learnerClone = ObjectUtil.cloneSmart(getLearner()); final LearnedType learned = learnerClone.learn(fold.getTrainingSet()); // Compute the statistic of the learned object on the testing set. final Collection<FoldDataType> testingSet = fold.getTestingSet(); final StatisticType statistic = getPerformanceEvaluator().evaluatePerformance( learned, testingSet); fireTrialEnded(); return statistic; } } }