/**
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.dtrain.nn;

import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import ml.shifu.shifu.core.dtrain.dataset.FloatFlatNetwork;
import ml.shifu.shifu.core.dtrain.dataset.FloatMLDataSet;

import org.encog.mathutil.error.ErrorCalculation;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ParallelGradient} is adapted from the Encog framework. The reason is that the original Encog Gradient
 * class does not expose its computed gradients, while we need the gradients accumulated into {@link NNMaster} to
 * update NN weights.
 */
public class ParallelGradient {

    protected static final Logger LOG = LoggerFactory.getLogger(ParallelGradient.class);

    /**
     * The network to train.
     */
    private FloatFlatNetwork network;

    /**
     * The training data.
     */
    private final FloatMLDataSet training;

    /**
     * The testing data; the test data set here is used for training and testing cross-over.
     */
    private final FloatMLDataSet testing;

    /**
     * Whether to swap training and testing elements (cross-over).
     */
    private final boolean isCrossOver;

    /**
     * Seed used to sample the training and testing data sets, to choose which elements are used for training.
     */
    private long seed = System.currentTimeMillis();

    /**
     * Derivative add constant. Used to combat the flat-spot problem, where a saturated activation function's
     * near-zero derivative stalls weight updates.
     */
    private double[] flatSpot;

    /**
     * The error function to use.
     */
    private final ErrorFunction errorFunction;

    private final int threadCount;

    /**
     * Per-thread inclusive [low, high] record ranges over the training data.
     */
    private long[] trainLows;

    private long[] trainHighs;

    private double trainError;

    @SuppressWarnings("unused")
    private double testError;

    /**
     * Per-thread inclusive [low, high] record ranges over the testing data.
     */
    private long[] testLows;

    private long[] testHighs;

    private SubGradient[] subGradients;
    /**
     * Thread pool used to do gradient computing and test-set error computing with multiple threads.
     */
    private ExecutorService threadPool;

    /**
     * Whether training is driven by an extreme learning machine:
     * https://en.wikipedia.org/wiki/Extreme_learning_machine
     */
    private boolean isELM;

    public ParallelGradient(final FloatFlatNetwork theNetwork, final FloatMLDataSet theTraining,
            final FloatMLDataSet theTesting, final double[] flatSpot, ErrorFunction ef, boolean isCrossOver,
            int threadCount, boolean isELM) {
        this.isELM = isELM;
        assert threadCount > 0 && threadCount < 33;
        this.threadCount = threadCount;

        this.training = theTraining;
        long recordCount = this.training.getRecordCount();
        this.trainLows = new long[threadCount];
        this.trainHighs = new long[threadCount];
        // TODO not very good for such a case: 80% in memory, 20% on disk, while all disk records are split into
        // one thread
        long stepCount = recordCount / threadCount;
        if(recordCount % threadCount != 0) {
            // bump step count to spread out the remainder, to avoid a worst-case last thread with up to
            // 2 * stepCount - 1 records
            stepCount += (recordCount % threadCount) / stepCount;
        }
        for(int i = 0; i < threadCount; i++) {
            this.trainLows[i] = i * stepCount;
            if(i != threadCount - 1) {
                this.trainHighs[i] = this.trainLows[i] + stepCount - 1;
            } else {
                this.trainHighs[i] = recordCount - 1;
            }
        }
        LOG.info("Train record count: {}", recordCount);
        LOG.info("Train lows: {}", Arrays.toString(trainLows));
        LOG.info("Train highs: {}", Arrays.toString(trainHighs));

        this.testing = theTesting;
        long testRecordCount = this.testing.getRecordCount();
        this.testLows = new long[threadCount];
        this.testHighs = new long[threadCount];
        long testStepCount = testRecordCount / threadCount;
        if(testRecordCount % threadCount != 0) {
            // bump step count to spread out the remainder, to avoid a worst-case last thread with up to
            // 2 * testStepCount - 1 records
            testStepCount += (testRecordCount % threadCount) / testStepCount;
        }
        for(int i = 0; i < threadCount; i++) {
            this.testLows[i] = i * testStepCount;
            if(i != threadCount - 1) {
                this.testHighs[i] = this.testLows[i] + testStepCount - 1;
            } else {
                this.testHighs[i] = testRecordCount - 1;
            }
        }
        LOG.info("Test record count: {}", testRecordCount);
        LOG.info("Test lows: {}", Arrays.toString(testLows));
        LOG.info("Test highs: {}", Arrays.toString(testHighs));

        this.network = theNetwork;
        this.isCrossOver = isCrossOver;
        this.flatSpot = flatSpot;
        this.errorFunction = ef;
        this.threadPool = Executors.newFixedThreadPool(this.threadCount);
    }
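    /*
     * Worked example of the record partitioning above (illustrative only, not part of the original source):
     * with recordCount = 10 and threadCount = 4, stepCount = 10 / 4 = 2 and the remainder is 2, so the
     * adjustment adds 2 / 2 = 1, giving stepCount = 3 and ranges [0, 2], [3, 5], [6, 8], [9, 9]. Without the
     * adjustment the last thread would process [6, 9], i.e. 2 * stepCount records.
     */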
    /**
     * Compute gradients in parallel: each worker computes gradients over its own record range, then the
     * per-thread gradients are summed and the per-thread errors are averaged, weighted by record count.
     */
    public double[] computeGradients() {
        CompletionService<double[]> completionService = new ExecutorCompletionService<double[]>(this.threadPool);

        // lazily create the sub-gradient workers once; later iterations only refresh their network copies
        if(this.subGradients == null) {
            this.subGradients = new SubGradient[this.threadCount];
        }
        for(int i = 0; i < this.threadCount; i++) {
            if(this.subGradients[i] == null) {
                this.subGradients[i] = new SubGradient(this.network.clone(), this.training, this.trainLows[i],
                        this.trainHighs[i], this.testing, this.testLows[i], this.testHighs[i], this.flatSpot,
                        this.errorFunction, this.isCrossOver, this);
            } else {
                this.subGradients[i].setNetwork(this.network.clone());
            }
            this.subGradients[i].setSeed(this.getSeed());
            completionService.submit(this.subGradients[i]);
        }

        int rCnt = 0;
        double[] finalGradients = new double[this.getNetwork().getWeights().length];
        while(rCnt < this.threadCount) {
            double[] gradients = null;
            try {
                gradients = completionService.take().get();
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            // gradients stays null if we were interrupted; guard to avoid an NPE
            if(gradients != null) {
                for(int i = 0; i < finalGradients.length; i++) {
                    finalGradients[i] += gradients[i];
                }
            }
            rCnt += 1;
        }

        double errorSum = 0d;
        for(int i = 0; i < this.threadCount; i++) {
            // weight each thread's error by the number of records (and outputs) it processed
            errorSum += this.subGradients[i].getError() * (trainHighs[i] - trainLows[i] + 1)
                    * this.getNetwork().getOutputCount();
        }
        this.trainError = errorSum / (this.training.getRecordCount() * this.getNetwork().getOutputCount());
        return finalGradients;
    }

    /**
     * @return the seed
     */
    public long getSeed() {
        return seed;
    }

    /**
     * @param seed
     *            the seed to set
     */
    public void setSeed(long seed) {
        this.seed = seed;
    }

    /**
     * @return the trainError
     */
    public double getTrainError() {
        return trainError;
    }

    /**
     * @return the network
     */
    public FlatNetwork getNetwork() {
        return network;
    }

    /**
     * Compute the testing error in parallel; all workers accumulate into one shared {@link ErrorCalculation}.
     */
    public double calculateError() {
        CompletionService<Double> completionService = new ExecutorCompletionService<Double>(this.threadPool);
        final ErrorCalculation ec = new ErrorCalculation();
        for(int i = 0; i < this.threadCount; i++) {
            final SubGradient subGradient = this.subGradients[i];
            completionService.submit(new Callable<Double>() {
                @Override
                public Double call() throws Exception {
                    return subGradient.calculateError(ec);
                }
            });
        }

        int rCnt = 0;
        while(rCnt < this.threadCount) {
            try {
                completionService.take().get();
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            rCnt += 1;
        }
        return ec.calculate();
    }

    /**
     * Average the weights of all sub-gradient networks and set the result on the current network.
     */
    public void resetNetworkWeights() {
        double[] weights = new double[this.network.getWeights().length];
        for(int i = 0; i < subGradients.length; i++) {
            double[] subWeights = subGradients[i].getNetwork().getWeights();
            for(int j = 0; j < weights.length; j++) {
                weights[j] += subWeights[j];
            }
        }
        for(int j = 0; j < weights.length; j++) {
            weights[j] /= subGradients.length;
        }
        this.network.setWeights(weights);
    }

    /**
     * Shut down the thread pool; should be called last to make sure the JVM can exit.
     */
    public void shutdown() {
        this.threadPool.shutdownNow();
        try {
            this.threadPool.awaitTermination(2, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * @return the isELM
     */
    public boolean isELM() {
        return isELM;
    }
}
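
/*
 * A minimal usage sketch, not part of the original source. It assumes the caller has already built a
 * FloatFlatNetwork and the training/testing FloatMLDataSets elsewhere (for example via Shifu's dataset
 * loaders). LinearErrorFunction is Encog's default error function; the flat-spot array is sized to the
 * number of activation functions, following Encog's propagation trainers, and all zeros disable the fix.
 */
class ParallelGradientUsageSketch {

    static void runOneIteration(FloatFlatNetwork network, FloatMLDataSet train, FloatMLDataSet test) {
        double[] flatSpot = new double[network.getActivationFunctions().length]; // zeros: no flat-spot fix
        ParallelGradient pg = new ParallelGradient(network, train, test, flatSpot,
                new org.encog.neural.error.LinearErrorFunction(), false /* no cross-over */, 4 /* threads */,
                false /* not ELM */);
        try {
            double[] gradients = pg.computeGradients(); // summed over all worker threads
            // In Shifu these gradients would be accumulated by NNMaster to compute the weight update;
            // here we only log the training error of this pass.
            ParallelGradient.LOG.info("train error: {}, gradient count: {}", pg.getTrainError(), gradients.length);
        } finally {
            pg.shutdown(); // release worker threads so the JVM can exit
        }
    }
}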