package de.jungblut.online.minimizer;

import java.util.Deque;
import java.util.LinkedList;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.StampedLock;
import java.util.function.Supplier;
import java.util.stream.Stream;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.online.ml.FeatureOutcomePair;
import de.jungblut.online.regularization.CostWeightTuple;
import de.jungblut.online.regularization.GradientDescentUpdater;
import de.jungblut.online.regularization.WeightUpdater;

/**
 * Stochastic gradient descent. This class is designed to work on a parallel
 * stream and do stochastic updates to a parameter set.
 *
 * <p>Note: when the supplied stream is parallel, every step is serialized
 * through a single write lock, so parallelism only helps if the stream source
 * itself (e.g. IO/decoding) benefits from it.
 *
 * @author thomas.jungblut
 *
 */
public class StochasticGradientDescent implements StochasticMinimizer {

  private static final Logger LOG = LogManager
      .getLogger(StochasticGradientDescent.class);

  public static class StochasticGradientDescentBuilder {

    private final double alpha;
    private double breakDifference;
    private double momentum;
    private int historySize = 10;
    private int progressReportInterval = 1;
    private double holdoutValidationPercentage = 0d;
    private boolean adaptiveLearningRate = false;
    private WeightUpdater weightUpdater = new GradientDescentUpdater();
    private long validationRandomSeed = System.currentTimeMillis();

    private StochasticGradientDescentBuilder(double alpha) {
      this.alpha = alpha;
    }

    public StochasticGradientDescent build() {
      return new StochasticGradientDescent(this);
    }

    /**
     * Add momentum to this gradient descent minimizer.
     *
     * @param momentum the momentum to use. Between 0 and 1.
     * @return the builder again.
     */
    public StochasticGradientDescentBuilder momentum(double momentum) {
      Preconditions.checkArgument(momentum >= 0d && momentum <= 1d,
          "Momentum must be between 0 and 1.");
      this.momentum = momentum;
      return this;
    }

    /**
     * In order to fix the reproducibility of a given train/test set split, you
     * can pass the seed value.
     *
     * @param seed the seed as passed into {@link java.util.Random}.
     * @return the builder again.
     */
    public StochasticGradientDescentBuilder validationRandomSeed(long seed) {
      this.validationRandomSeed = seed;
      return this;
    }

    /**
     * Holdout validation percentage, this will take a subset of the data on the
     * stream and do a validation on it.
     *
     * @param perc the percentage to use. Between 0 and 1.
     * @return the builder again.
     */
    public StochasticGradientDescentBuilder holdoutValidationPercentage(
        double perc) {
      // BUGFIX: this previously validated "momentum" instead of "perc",
      // silently accepting out-of-range percentages.
      Preconditions.checkArgument(perc >= 0d && perc <= 1d,
          "HoldOut Percentage must be between 0 and 1.");
      this.holdoutValidationPercentage = perc;
      return this;
    }

    /**
     * Sets the weight updater, for example to use regularization. The default
     * is the normal gradient descent.
     *
     * To set the regularization parameter use the {@link #lambda(double)}
     * method.
     *
     * @param weightUpdater the updater to use.
     * @return the builder again.
     */
    public StochasticGradientDescentBuilder weightUpdater(
        WeightUpdater weightUpdater) {
      this.weightUpdater = Preconditions.checkNotNull(weightUpdater);
      return this;
    }

    /**
     * Sets the size of the history to keep to compute average improvements and
     * output progress information.
     *
     * @param historySize the number of cost values to retain. Must be &gt; 0.
     * @return the builder again.
     */
    public StochasticGradientDescentBuilder historySize(int historySize) {
      Preconditions.checkArgument(historySize > 0, "HistorySize must be > 0");
      this.historySize = historySize;
      return this;
    }

    /**
     * Sets the progress report interval. Since writing to the console/log might
     * be expensive, this is an easy way to limit the logging if needed.
     *
     * @param interval the interval. E.g. every 10th iteration.
     * @return the builder again.
     */
    public StochasticGradientDescentBuilder progressReportInterval(int interval) {
      Preconditions.checkArgument(interval > 0, "ReportInterval must be > 0");
      this.progressReportInterval = interval;
      return this;
    }

    /**
     * Breaks minimization process when the given delta in costs have been
     * achieved. Usually a quite low value of 1e-4 to 1e-8.
     *
     * @param delta the delta to break in difference between two costs.
     * @return the builder again.
     */
    public StochasticGradientDescentBuilder breakOnDifference(double delta) {
      this.breakDifference = delta;
      return this;
    }

    /**
     * Enables adaptive learning rate, using the algorithm: <br/>
     *
     * <pre>
     * alpha = 1d / (initialAlpha * (allIterations + 2));
     * </pre>
     *
     * where allIterations is a counter over all passes.
     *
     * @return the builder again
     */
    public StochasticGradientDescentBuilder enableAdaptiveLearningRate() {
      this.adaptiveLearningRate = true;
      return this;
    }

    /**
     * Creates a new builder.
     *
     * @param alpha the learning rate to set.
     * @return a new builder.
     */
    public static StochasticGradientDescentBuilder create(double alpha) {
      return new StochasticGradientDescentBuilder(alpha);
    }

  }

  private final StochasticGradientDescentBuilder builder;
  private final long validationSeed;
  private final StampedLock lock = new StampedLock();

  private IterationFinishedCallback iterationCallback;
  private ValidationFinishedCallback validationCallback;
  private PassFinishedCallback passCallback;

  private double breakDifference;
  private double momentum;
  private double initialAlpha;
  private double validationPercentage;
  private int historySize;
  private int progressReportInterval;
  private WeightUpdater weightUpdater;

  // we are fixing the random for validation to generate the same sequences
  // to not mix train and validation set.
  private Random validationRandom;
  private Deque<Double> costHistory;
  private DoubleVector lastTheta = null;
  private DoubleVector theta;
  private double alpha;
  private int validationItems;
  private double validationError;
  private double trainingError;
  private boolean stopAfterThisPass = false;
  private boolean adaptiveLearningRate = false;
  private long iteration = 0;
  private long allIterations = 0;
  private Stopwatch startWatch;

  private StochasticGradientDescent(StochasticGradientDescentBuilder builder) {
    this.builder = builder;
    this.validationSeed = builder.validationRandomSeed;
    resetState(builder);
  }

  /**
   * Resets all per-run mutable state from the builder so that
   * {@link #minimize} can be called repeatedly on the same instance.
   */
  private void resetState(StochasticGradientDescentBuilder builder) {
    this.initialAlpha = builder.alpha;
    this.alpha = this.initialAlpha;
    this.breakDifference = builder.breakDifference;
    this.momentum = builder.momentum;
    this.progressReportInterval = builder.progressReportInterval;
    this.historySize = builder.historySize;
    this.weightUpdater = builder.weightUpdater;
    this.validationPercentage = builder.holdoutValidationPercentage;
    this.adaptiveLearningRate = builder.adaptiveLearningRate;
    this.costHistory = new LinkedList<>();
    // BUGFIX: reset the convergence flag and iteration counters as well,
    // otherwise a second minimize() call would stop after its first pass and
    // reuse a stale adaptive learning rate / momentum state.
    this.stopAfterThisPass = false;
    this.lastTheta = null;
    this.allIterations = 0;
  }

  @Override
  public DoubleVector minimize(DoubleVector start,
      Supplier<Stream<FeatureOutcomePair>> streamSupplier,
      StochasticCostFunction costFunction, int numPasses, boolean verbose) {

    resetState(builder);

    theta = start;
    startWatch = Stopwatch.createStarted();
    for (int pass = 0; pass < numPasses; pass++) {
      // re-seed per pass so the train/validation split is identical each pass
      validationRandom = new Random(validationSeed);
      iteration = 0;
      trainingError = 0;
      validationError = 0;
      validationItems = 0;
      Stream<FeatureOutcomePair> currentStream = streamSupplier.get();
      final int passFinal = pass;
      if (currentStream.isParallel()) {
        currentStream.forEach((next) -> doStepLocked(passFinal, next,
            costFunction, verbose));
      } else {
        currentStream.forEach((next) -> doStep(passFinal, next, costFunction,
            verbose));
      }
      if (verbose) {
        LOG.info(String
            .format(
                "Pass Summary %d | Iteration %d | Validation Cost: %g | Training Cost: %g | Iterations/s: %g | Total Time Taken: %s",
                pass, iteration,
                validationError / Math.max(validationItems, 1), trainingError
                    / Math.max(iteration - validationItems, 1), allIterations
                    / (double) Math.max(startWatch.elapsed(TimeUnit.SECONDS),
                        1), startWatch));
      }
      if (passCallback != null) {
        boolean continuePass = passCallback.onPassFinished(pass, iteration,
            validationError, theta);
        // break this pass, because the callback said so
        if (!continuePass) {
          break;
        }
      }
      if (stopAfterThisPass) {
        break;
      }
    }

    return theta;
  }

  // TODO this write lock is huge, can it be broken down more?
  private void doStepLocked(int pass, FeatureOutcomePair next,
      StochasticCostFunction costFunction, boolean verbose) {
    Lock writeLock = lock.asWriteLock();
    // acquire before the try block: if lock() itself failed we must not
    // attempt to unlock a lock we never held
    writeLock.lock();
    try {
      doStep(pass, next, costFunction, verbose);
    } finally {
      writeLock.unlock();
    }
  }

  /**
   * Executes a single stochastic step: observes one example, either records it
   * as a validation sample (no weight update) or updates the weights, momentum
   * and learning rate, and checks for convergence.
   */
  private void doStep(int pass, FeatureOutcomePair next,
      StochasticCostFunction costFunction, boolean verbose) {

    DoubleVector iterationLocalTheta = Preconditions.checkNotNull(weightUpdater
        .prePredictionWeightUpdate(next, theta, alpha, allIterations),
        "weight updater #prePredictionWeightUpdate return must be non-null!");

    CostGradientTuple observed = costFunction.observeExample(next,
        iterationLocalTheta);

    if (verbose) {
      double avgImprovement = getAverageImprovement(costHistory);
      if (iteration > 0 && iteration % progressReportInterval == 0) {
        LOG.info(String
            .format(
                "Pass %d | Iteration %d | Validation Cost: %g | Training Cost: %g | Avg Improvement: %g | Iterations/s: %g",
                pass, iteration,
                validationError / Math.max(validationItems, 1), trainingError
                    / Math.max(iteration - validationItems, 1), avgImprovement,
                allIterations
                    / (double) Math.max(startWatch.elapsed(TimeUnit.SECONDS),
                        1)));
      }
    }

    dropOldValues(costHistory);

    boolean validation = false;
    if (validationPercentage > 0) {
      if (validationRandom.nextDouble() < validationPercentage) {
        validationError += observed.getCost();
        validationItems++;
        // update the history
        costHistory.addLast(validationError / Math.max(validationItems, 1));
        validation = true;
        if (validationCallback != null) {
          validationCallback.onValidationFinished(pass, iteration,
              observed.getCost(), iterationLocalTheta, next);
        }
      }
    } else {
      costHistory.addLast(observed.getCost() / Math.max(iteration, 1));
    }

    if (iterationCallback != null) {
      iterationCallback.onIterationFinished(pass, iteration,
          observed.getCost(), iterationLocalTheta, validation);
    }

    if (validation) {
      // return to not update the parameters when we did a validation step
      return;
    }

    trainingError += observed.getCost();
    CostWeightTuple update = updateWeights(iterationLocalTheta, observed);

    // save our last parameter
    lastTheta = iterationLocalTheta;
    theta = update.getWeight();
    computeMomentum();

    // break if we converged below the limit
    if (converged(costHistory, breakDifference)) {
      stopAfterThisPass = true;
    }

    allIterations++;
    iteration++;
    if (adaptiveLearningRate) {
      alpha = 1d / (initialAlpha * (allIterations + 2));
    }
  }

  public void computeMomentum() {
    // compute momentum
    if (lastTheta != null && momentum != 0d) {
      // we add momentum as the parameter "m" multiplied by the
      // difference of both theta vectors
      theta = theta.add((lastTheta.subtract(theta)).multiply(momentum));
    }
  }

  public CostWeightTuple updateWeights(DoubleVector iterationLocalTheta,
      CostGradientTuple observed) {
    return weightUpdater.computeNewWeights(iterationLocalTheta,
        observed.getGradient(), alpha, allIterations, observed.getCost());
  }

  public void setIterationCallback(IterationFinishedCallback iterationCallback) {
    this.iterationCallback = iterationCallback;
  }

  public void setValidationCallback(
      ValidationFinishedCallback validationCallback) {
    this.validationCallback = validationCallback;
  }

  public void setPassCallback(PassFinishedCallback passCallback) {
    this.passCallback = passCallback;
  }

  // TODO this should use a cyclic buffer instead of a deque
  private void dropOldValues(Deque<Double> lastCosts) {
    while (lastCosts.size() > historySize) {
      lastCosts.pop();
    }
  }

  /**
   * @return true if the average improvement over the history window dropped
   *         below the given limit.
   */
  private boolean converged(Deque<Double> lastCosts, double limit) {
    return Math.abs(getAverageImprovement(lastCosts)) < limit;
  }

  /**
   * @return the average cost change per step between the oldest and newest
   *         history entry, or 0 if fewer than two entries exist.
   */
  private double getAverageImprovement(Deque<Double> lastCosts) {
    if (lastCosts.size() >= 2) {
      double first = lastCosts.peek();
      double value = lastCosts.peekLast();
      return (value - first) / lastCosts.size();
    }
    return 0d;
  }

}