/*
 * File:                ParallelDirichletProcessMixtureModel.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright May 3, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel.DPMMLogConditional;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A Parallelized version of vanilla Dirichlet Process Mixture Model learning.
 * In particular, this class parallelizes the assignment of observations to
 * clusters and the Gibbs sampling updating of clusters from their constituent
 * observations.
 * @param <ObservationType>
 * Type of observations handled by the algorithm
 * @author Kevin R. Dixon
 * @since 3.0
 */
public class ParallelDirichletProcessMixtureModel<ObservationType>
    extends DirichletProcessMixtureModel<ObservationType>
    implements ParallelAlgorithm
{

    /**
     * Thread pool used for parallelization.
     */
    private transient ThreadPoolExecutor threadPool;

    /**
     * Creates a new instance of ParallelDirichletProcessMixtureModel
     */
    public ParallelDirichletProcessMixtureModel()
    {
        super();
    }

    /**
     * Gets the number of threads, as reported by
     * {@code ParallelUtil.getNumThreads} for this algorithm.
     * @return
     * Number of threads used to partition the work.
     */
    public int getNumThreads()
    {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Gets the thread pool, creating one through
     * {@code ParallelUtil.createThreadPool()} on first use if none has been
     * set.
     * @return
     * Thread pool used for parallelization, never null after this call
     * unless {@code ParallelUtil.createThreadPool()} itself returns null.
     */
    public ThreadPoolExecutor getThreadPool()
    {
        if (this.threadPool == null)
        {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }

        return this.threadPool;
    }

    /**
     * Sets the thread pool used for parallelization.
     * @param threadPool
     * Thread pool to use; null causes a new pool to be created lazily by
     * {@link #getThreadPool()}.
     */
    public void setThreadPool(
        final ThreadPoolExecutor threadPool)
    {
        this.threadPool = threadPool;
    }

    /**
     * Tasks that assign observations to clusters
     */
    transient protected ArrayList<ObservationAssignmentTask> assignmentTasks;

    /**
     * The data object from which {@code assignmentTasks} was partitioned.
     * Compared by identity so that the cached tasks (which hold sub-list
     * views of the data) are rebuilt when this learner is handed a different
     * data set, instead of silently re-assigning the old observations.
     */
    private transient Object assignmentTasksData;

    /**
     * Assigns every observation to one of the K existing clusters or to the
     * as-yet-uncreated new cluster, splitting the observations across the
     * assignment tasks and running those tasks on the thread pool.
     * @param K
     * Number of existing clusters; the result has K+1 entries.
     * @param logConditional
     * Accumulator to which the log conditional computed by each task is
     * added.
     * @return
     * List of K+1 collections, where entry k holds the observations assigned
     * to cluster index k.
     * @throws RuntimeException
     * If the parallel execution fails; the original exception is the cause.
     */
    @Override
    protected ArrayList<Collection<ObservationType>> assignObservationsToClusters(
        int K,
        DPMMLogConditional logConditional )
    {

        // Rebuild the partition when there is none yet, or when the cached
        // one was built from a different data object than the current one.
        if( (this.assignmentTasks == null)
            || (this.assignmentTasksData != this.data) )
        {
            this.assignmentTasks = this.createAssignmentTasks();
            this.assignmentTasksData = this.data;
        }

        ArrayList<DPMMAssignments> results;
        try
        {
            results = ParallelUtil.executeInParallel(
                this.assignmentTasks, this.getThreadPool() );
        }
        catch( Exception ex )
        {
            throw wrapParallelFailure( ex );
        }

        // This assigns observations to each of the K clusters, plus the
        // as-yet-uncreated new cluster
        ArrayList<Collection<ObservationType>> clusterAssignments =
            new ArrayList<Collection<ObservationType>>( K+1 );
        for( int k = 0; k < K+1; k++ )
        {
            clusterAssignments.add( new LinkedList<ObservationType>() );
        }

        for( int n = 0; n < results.size(); n++ )
        {
            final DPMMAssignments result = results.get(n);
            logConditional.logConditional +=
                result.logConditional.logConditional;

            // The n-th result lines up with the n-th task, and its
            // assignments are in the iteration order of that task's
            // observations.
            final ArrayList<Integer> assignments = result.assignments;
            int index = 0;
            for( ObservationType observation
                : this.assignmentTasks.get(n).observations )
            {
                int assignment = assignments.get(index);
                clusterAssignments.get(assignment).add( observation );
                index++;
            }
        }

        return clusterAssignments;

    }

    /**
     * Partitions the current data into one contiguous slice per thread and
     * creates an assignment task for each slice. The first numThreads-1
     * tasks get N/numThreads observations each (integer division) and the
     * last task gets the remainder.
     * @return
     * Newly created tasks, one per thread.
     */
    private ArrayList<ObservationAssignmentTask> createAssignmentTasks()
    {
        ArrayList<? extends ObservationType> dataArray =
            CollectionUtil.asArrayList( this.data );
        final int N = dataArray.size();
        final int numThreads = this.getNumThreads();
        ArrayList<ObservationAssignmentTask> tasks =
            new ArrayList<ObservationAssignmentTask>( numThreads );
        final int numPerTask = N / numThreads;
        int endIndex = 0;
        for( int n = 0; n < numThreads-1; n++ )
        {
            int startIndex = endIndex;
            endIndex += numPerTask;
            tasks.add( new ObservationAssignmentTask(
                dataArray.subList( startIndex, endIndex ) ) );
        }

        // Last task picks up whatever is left over from the integer division
        tasks.add( new ObservationAssignmentTask(
            dataArray.subList( endIndex, N ) ) );
        return tasks;
    }

    /**
     * Converts a failure of the parallel execution into the unchecked
     * exception thrown by this class. If the failure was an interruption,
     * the interrupt status of the current thread is restored first so that
     * callers further up the stack can still observe it.
     * @param cause
     * Exception raised while executing tasks in parallel.
     * @return
     * RuntimeException wrapping the cause, to be thrown by the caller.
     */
    private static RuntimeException wrapParallelFailure(
        Exception cause )
    {
        if( cause instanceof InterruptedException )
        {
            Thread.currentThread().interrupt();
        }
        return new RuntimeException( cause );
    }

    /**
     * Assignments from the DPMM
     */
    public static class DPMMAssignments
    {

        /**
         * List of assignment indices
         */
        protected ArrayList<Integer> assignments;

        /**
         * Log conditional likelihood of the previous sample
         */
        protected DPMMLogConditional logConditional;

        /**
         * Constructor
         * @param assignments
         * List of assignment indices
         * @param logConditional
         * Log conditional likelihood of the previous sample
         */
        public DPMMAssignments(
            ArrayList<Integer> assignments,
            DPMMLogConditional logConditional)
        {
            this.assignments = assignments;
            this.logConditional = logConditional;
        }

    }

    /**
     * Task that assign observations to cluster indices
     */
    protected class ObservationAssignmentTask
        extends AbstractCloneableSerializable
        implements Callable<DPMMAssignments>
    {

        /**
         * Observations to assign
         */
        private Collection<? extends ObservationType> observations;

        /**
         * Weights that are re-used
         */
        private double[] weights;

        /**
         * Resulting assignments
         */
        private ArrayList<Integer> assignments;

        /**
         * Log conditional of the previous sample
         */
        private DPMMLogConditional logConditional;

        /**
         * Creates a new instance of ObservationAssignmentTask
         * @param observations
         * Observations to assign
         */
        public ObservationAssignmentTask(
            Collection<? extends ObservationType> observations )
        {
            this.weights = null;
            this.observations = observations;
        }

        /**
         * Assigns each of this task's observations to a cluster index using
         * the enclosing model's current parameter.
         * @return
         * The assignment indices, in the iteration order of the
         * observations, together with the log conditional accumulated over
         * this task's observations. The returned list is this task's own
         * buffer and is overwritten by the next call.
         * @throws Exception
         * Declared by the Callable contract.
         */
        public DPMMAssignments call()
            throws Exception
        {

            // The weights buffer needs one slot per existing cluster plus
            // one for the new cluster; only reallocate when K has changed.
            final int K = currentParameter.getNumClusters();
            if( (this.weights == null) ||
                (this.weights.length != K+1) )
            {
                this.weights = new double[ K+1 ];
            }

            // Allocate the result buffer once, pre-filled so that set() can
            // be used on every call afterwards.
            if( this.assignments == null )
            {
                final int size = this.observations.size();
                this.assignments = new ArrayList<Integer>( size );
                for( int n = 0; n < size; n++ )
                {
                    this.assignments.add( null );
                }
            }

            this.logConditional = new DPMMLogConditional();
            int index = 0;
            for( ObservationType observation : this.observations )
            {
                int clusterAssignment = assignObservationToCluster(
                    observation, this.weights, this.logConditional );
                this.assignments.set( index, clusterAssignment );
                index++;
            }

            return new DPMMAssignments(this.assignments, this.logConditional);

        }

    }

    /**
     * Tasks that update the values of the clusters for Gibbs sampling
     */
    transient protected ArrayList<ClusterUpdaterTask> clusterUpdaterTasks;

    /**
     * Creates the clusters for the next sample from the given assignments,
     * running one cluster-creation task per assignment collection on the
     * thread pool.
     * @param clusterAssignments
     * Observations assigned to each cluster index.
     * @return
     * The non-null clusters produced by the tasks, in assignment order.
     * Collections with one or fewer observations are handed to the task as
     * null; any cluster that comes back null is omitted from the result.
     * @throws RuntimeException
     * If the parallel execution fails; the original exception is the cause.
     */
    @Override
    protected ArrayList<DPMMCluster<ObservationType>> updateClusters(
        ArrayList<Collection<ObservationType>> clusterAssignments)
    {

        final int Kp1 = clusterAssignments.size();

        // The tasks are only reusable when there is exactly one per
        // assignment collection.
        if( (this.clusterUpdaterTasks == null) ||
            (this.clusterUpdaterTasks.size() != Kp1) )
        {
            this.clusterUpdaterTasks = new ArrayList<ClusterUpdaterTask>( Kp1 );
            for( int k = 0; k < Kp1; k++ )
            {
                this.clusterUpdaterTasks.add( new ClusterUpdaterTask() );
            }
        }

        for( int k = 0; k < Kp1; k++ )
        {
            Collection<ObservationType> observations =
                clusterAssignments.get(k);
            if( observations.size() <= 1 )
            {
                observations = null;
            }
            this.clusterUpdaterTasks.get(k).observations = observations;
        }

        ArrayList<DPMMCluster<ObservationType>> clusters = null;
        try
        {
            clusters = ParallelUtil.executeInParallel(
                this.clusterUpdaterTasks, this.getThreadPool() );
        }
        catch (Exception e)
        {
            throw wrapParallelFailure( e );
        }

        ArrayList<DPMMCluster<ObservationType>> results =
            new ArrayList<DPMMCluster<ObservationType>>( Kp1 );
        for( int k = 0; k < Kp1; k++ )
        {
            DPMMCluster<ObservationType> cluster = clusters.get(k);
            if( cluster != null )
            {
                results.add( cluster );
            }
        }

        return results;

    }

    /**
     * Tasks that update the values of the clusters for Gibbs sampling
     */
    protected class ClusterUpdaterTask
        extends AbstractCloneableSerializable
        implements Callable<DPMMCluster<ObservationType>>
    {

        /**
         * Observations that comprise the cluster
         */
        Collection<ObservationType> observations;

        /**
         * Local clone of the updater, needed to ensure thread safety
         */
        Updater<ObservationType> localUpdater;

        /**
         * Creates a new instance of ClusterUpdaterTask
         */
        public ClusterUpdaterTask()
        {
            // Cloned once, at construction time, from the enclosing model's
            // updater.
            this.localUpdater = ObjectUtil.cloneSafe( updater );
        }

        /**
         * Creates a cluster from this task's observations using the local
         * updater clone.
         * @return
         * Result of the enclosing model's createCluster for the
         * observations; callers treat a null result as "no cluster".
         */
        public DPMMCluster<ObservationType> call()
        {
            return createCluster(this.observations, this.localUpdater );
        }

    }

}