/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * ParallelIteratedSingleClassifierEnhancer.java * Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers; import java.util.Enumeration; import java.util.Vector; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import weka.core.Instances; import weka.core.Option; import weka.core.Utils; /** * Abstract utility class for handling settings common to * meta classifiers that build an ensemble in parallel from a single * base learner. * * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) * @version $Revision: 8034 $ */ public abstract class ParallelIteratedSingleClassifierEnhancer extends IteratedSingleClassifierEnhancer { /** For serialization */ private static final long serialVersionUID = -5026378741833046436L; /** The number of threads to have executing at any one time */ protected int m_numExecutionSlots = 1; /** Pool of threads to train models with */ protected transient ThreadPoolExecutor m_executorPool; /** The number of classifiers completed so far */ protected int m_completed; /** * The number of classifiers that experienced a failure of some sort * during construction */ protected int m_failed; /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(2); newVector.addElement(new Option( "\tNumber of execution slots.\n" + "\t(default 1 - i.e. no parallelism)", "num-slots", 1, "-num-slots <num>")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -Z num <br> * Set the number of execution slots to use (default 1 - i.e. no parallelism). <p> * * Options after -- are passed to the designated classifier.<p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String iterations = Utils.getOption("num-slots", options); if (iterations.length() != 0) { setNumExecutionSlots(Integer.parseInt(iterations)); } else { setNumExecutionSlots(1); } super.setOptions(options); } /** * Gets the current settings of the classifier. * * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] superOptions = super.getOptions(); String [] options = new String [superOptions.length + 2]; int current = 0; options[current++] = "-num-slots"; options[current++] = "" + getNumExecutionSlots(); System.arraycopy(superOptions, 0, options, current, superOptions.length); return options; } /** * Set the number of execution slots (threads) to use for building the * members of the ensemble. * * @param numSlots the number of slots to use. */ public void setNumExecutionSlots(int numSlots) { m_numExecutionSlots = numSlots; } /** * Get the number of execution slots (threads) to use for building * the members of the ensemble. * * @return the number of slots to use */ public int getNumExecutionSlots() { return m_numExecutionSlots; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numExecutionSlotsTipText() { return "The number of execution slots (threads) to use for " + "constructing the ensemble."; } /** * Stump method for building the classifiers * * @param data the training data to be used for generating the ensemble * @exception Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { super.buildClassifier(data); if (m_numExecutionSlots < 1) { throw new Exception("Number of execution slots needs to be >= 1!"); } if (m_numExecutionSlots > 1) { startExecutorPool(); } m_completed = 0; m_failed = 0; } /** * Start the pool of execution threads */ protected void startExecutorPool() { if (m_executorPool != null) { m_executorPool.shutdownNow(); } m_executorPool = new ThreadPoolExecutor(m_numExecutionSlots, m_numExecutionSlots, 120, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>()); } private synchronized void block(boolean tf) { if (tf) { try { if (m_numExecutionSlots > 1 && m_completed + m_failed < m_Classifiers.length) { wait(); } } catch (InterruptedException ex) { } } else { notifyAll(); } } /** * Does the actual construction of the ensemble * * @throws Exception if something goes wrong during the training * process */ protected synchronized void buildClassifiers() throws Exception { for (int i = 0; i < m_Classifiers.length; i++) { if (m_numExecutionSlots > 1) { final Classifier currentClassifier = m_Classifiers[i]; final int iteration = i; if (m_Debug) { System.out.print("Training classifier (" + (i +1) + ")"); } Runnable newTask = new Runnable() { public void run() { try { currentClassifier.buildClassifier(getTrainingSet(iteration)); completedClassifier(iteration, true); } catch (Exception ex) { ex.printStackTrace(); completedClassifier(iteration, false); } } }; // launch this task m_executorPool.execute(newTask); } else { m_Classifiers[i].buildClassifier(getTrainingSet(i)); } } if (m_numExecutionSlots > 1 && m_completed + m_failed < m_Classifiers.length) { block(true); } } /** * Records the completion of the training of a single classifier. Unblocks if * all classifiers have been trained. * * @param iteration the iteration that has completed * @param success whether the classifier trained successfully */ protected synchronized void completedClassifier(int iteration, boolean success) { if (!success) { m_failed++; if (m_Debug) { System.err.println("Iteration " + iteration + " failed!"); } } else { m_completed++; } if (m_completed + m_failed == m_Classifiers.length) { if (m_failed > 0) { if (m_Debug) { System.err.println("Problem building classifiers - some iterations failed."); } } // have to shut the pool down or program executes as a server // and when running from the command line does not return to the // prompt m_executorPool.shutdown(); block(false); } } /** * Gets a training set for a particular iteration. Implementations need * to be careful with thread safety and should probably be synchronized * to be on the safe side. * * @param iteration the number of the iteration for the requested training set * @return the training set for the supplied iteration number * @throws Exception if something goes wrong. */ protected abstract Instances getTrainingSet(int iteration) throws Exception; }