/* * File: ParallelizedCostFunctionContainer.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Sep 22, 2008, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.function.cost; import gov.sandia.cognition.algorithm.ParallelAlgorithm; import gov.sandia.cognition.algorithm.ParallelUtil; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.data.SequentialDataMultiPartitioner; import gov.sandia.cognition.learning.data.TargetEstimatePair; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.util.ObjectUtil; import java.util.ArrayList; import java.util.Collection; import java.util.concurrent.Callable; import java.util.concurrent.ThreadPoolExecutor; import java.util.logging.Level; import java.util.logging.Logger; /** * A cost function that automatically splits a ParallelizableCostFunction * across multiple cores/processors to speed up computation. * @author Kevin R. Dixon * @since 2.1 */ public class ParallelizedCostFunctionContainer extends AbstractSupervisedCostFunction<Vector,Vector> implements DifferentiableCostFunction, ParallelAlgorithm { /** * Cost function to parallelize */ private ParallelizableCostFunction costFunction; /** * Collection of evaluation thread calls */ private transient ArrayList<Callable<Object>> evaluationComponents; /** * Collection of evaluation gradient calls */ private transient ArrayList<Callable<Object>> gradientComponents; /** * Thread pool used to parallelize the computation */ private transient ThreadPoolExecutor threadPool; /** * Default constructor for ParallelizedCostFunctionContainer. */ public ParallelizedCostFunctionContainer() { this( (ParallelizableCostFunction) null ); } /** * Creates a new instance of ParallelizedCostFunctionContainer * @param costFunction * Cost function to parallelize */ public ParallelizedCostFunctionContainer( ParallelizableCostFunction costFunction ) { this( costFunction, ParallelUtil.createThreadPool() ); } /** * Creates a new instance of ParallelizedCostFunctionContainer * @param threadPool * Thread pool used to parallelize the computation * @param costFunction * Cost function to parallelize */ public ParallelizedCostFunctionContainer( ParallelizableCostFunction costFunction, ThreadPoolExecutor threadPool ) { this.setCostFunction( costFunction ); this.setThreadPool( threadPool ); } @Override public ParallelizedCostFunctionContainer clone() { ParallelizedCostFunctionContainer clone = (ParallelizedCostFunctionContainer) super.clone(); clone.setCostFunction( ObjectUtil.cloneSafe( this.getCostFunction() ) ); clone.setThreadPool( ParallelUtil.createThreadPool( this.getNumThreads() ) ); return clone; } /** * Getter for costFunction * @return * Cost function to parallelize */ public ParallelizableCostFunction getCostFunction() { return this.costFunction; } /** * Setter for costFunction * @param costFunction * Cost function to parallelize */ public void setCostFunction( ParallelizableCostFunction costFunction ) { this.costFunction = costFunction; this.evaluationComponents = null; this.gradientComponents = null; } /** * Splits the data across the numComponents cost functions */ protected void createPartitions() { int numThreads = this.getNumThreads(); ArrayList<ArrayList<InputOutputPair<? extends Vector, Vector>>> partitions = SequentialDataMultiPartitioner.create( this.getCostParameters(), numThreads ); this.evaluationComponents = new ArrayList<Callable<Object>>( numThreads ); this.gradientComponents = new ArrayList<Callable<Object>>( numThreads ); for( int i = 0; i < numThreads; i++ ) { ParallelizableCostFunction subcost = (ParallelizableCostFunction) this.getCostFunction().clone(); subcost.setCostParameters( partitions.get(i) ); this.evaluationComponents.add( new SubCostEvaluate( subcost, null ) ); this.gradientComponents.add( new SubCostGradient( subcost, null ) ); } } @Override public void setCostParameters( Collection<? extends InputOutputPair<? extends Vector, Vector>> costParameters ) { super.setCostParameters( costParameters ); this.evaluationComponents = null; this.gradientComponents = null; } @Override public Double evaluate( Evaluator<? super Vector, ? extends Vector> evaluator ) { if( this.evaluationComponents == null ) { this.createPartitions(); } // Set the subtasks for( Callable<Object> sce : this.evaluationComponents ) { ((SubCostEvaluate) sce).evaluator = evaluator; } Collection<Object> partialResults = null; try { partialResults = ParallelUtil.executeInParallel( this.evaluationComponents, this.getThreadPool() ); } catch (Exception ex) { Logger.getLogger( ParallelizedCostFunctionContainer.class.getName() ).log( Level.SEVERE, null, ex ); } return this.getCostFunction().evaluateAmalgamate( partialResults ); } @Override public Double evaluatePerformance( Collection<? extends TargetEstimatePair<? extends Vector, ? extends Vector>> data ) { return this.getCostFunction().evaluatePerformance( data ); } public Vector computeParameterGradient( GradientDescendable function ) { if (this.gradientComponents == null) { this.createPartitions(); } // Create the subtasks for (Callable<Object> eval : this.gradientComponents) { ((SubCostGradient) eval).evaluator = function; } Collection<Object> results = null; try { results = ParallelUtil.executeInParallel( this.gradientComponents, this.getThreadPool() ); } catch (Exception ex) { Logger.getLogger( ParallelizedCostFunctionContainer.class.getName() ).log( Level.SEVERE, null, ex ); } return this.getCostFunction().computeParameterGradientAmalgamate( results ); } public ThreadPoolExecutor getThreadPool() { if( this.threadPool == null ) { this.setThreadPool( ParallelUtil.createThreadPool() ); } return this.threadPool; } public void setThreadPool( ThreadPoolExecutor threadPool ) { this.threadPool = threadPool; } public int getNumThreads() { return ParallelUtil.getNumThreads( this ); } /** * Creates the thread pool using the Foundry's global thread pool. */ protected void createThreadPool() { this.setThreadPool( ParallelUtil.createThreadPool() ); } /** * Callable task for the evaluate() method. */ protected static class SubCostEvaluate implements Callable<Object> { /** * Parallel cost function */ private ParallelizableCostFunction costFunction; /** * Evaluator for which to compute the cost */ private Evaluator<? super Vector, ? extends Vector> evaluator; /** * Creates a new instance of SubCostEvaluate * @param costFunction * Parallel cost function * @param evaluator * Evaluator for which to compute the cost */ public SubCostEvaluate( ParallelizableCostFunction costFunction, Evaluator<? super Vector, ? extends Vector> evaluator ) { this.costFunction = costFunction; this.evaluator = evaluator; } public Object call() { return this.costFunction.evaluatePartial( this.evaluator ); } } /** * Callable task for the computeGradient() method */ protected static class SubCostGradient implements Callable<Object> { /** * Parallel cost function */ private ParallelizableCostFunction costFunction; /** * Function for which to compute the gradient */ private GradientDescendable evaluator; /** * Creates a new instance of SubCostGradient * @param costFunction * Parallel cost function * @param evaluator * Function for which to compute the gradient */ public SubCostGradient( ParallelizableCostFunction costFunction, GradientDescendable evaluator ) { this.costFunction = costFunction; this.evaluator = evaluator; } public Object call() { return this.costFunction.computeParameterGradientPartial( this.evaluator ); } } }