package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.util.RuntimeInterruptedException;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Created on 8/26/15.
 * @author keenon
 * <p>
 * Abstract base of all the different kinds of optimizers. This exists both to facilitate sharing tests between
 * optimizers and to share certain basic bits of functionality useful for batch optimizers, like intelligent
 * multi-thread management and user interrupt handling.
 */
public abstract class AbstractBatchOptimizer {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(AbstractBatchOptimizer.class);

  public <T> ConcatVector optimize(T[] dataset, AbstractDifferentiableFunction<T> fn) {
    return optimize(dataset, fn, new ConcatVector(0), 0.0, 1.0e-5, false);
  }

  public <T> ConcatVector optimize(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights,
                                   double l2regularization, double convergenceDerivativeNorm, boolean quiet) {
    if (!quiet) log.info("\n**************\nBeginning training\n");
    else log.info("[Beginning quiet training]");

    TrainingWorker<T> mainWorker = new TrainingWorker<>(dataset, fn, initialWeights, l2regularization, convergenceDerivativeNorm, quiet);
    new Thread(mainWorker).start();

    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

    if (!quiet) {
      log.info("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
      log.info("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
      log.info("if left to their own devices.\n");

      // Poll for either natural termination or a keypress on stdin
      while (true) {
        if (mainWorker.isFinished) {
          log.info("training completed without interruption");
          return mainWorker.weights;
        }
        try {
          if (br.ready()) {
            log.info("received quit command: quitting");
            log.info("training completed by interruption");
            mainWorker.isFinished = true;
            return mainWorker.weights;
          }
        } catch (IOException e) {
          e.printStackTrace();
        }
      }
    } else {
      // Hold the barrier's monitor while checking isFinished, so a notifyAll() between the check and the
      // wait() can't be missed
      synchronized (mainWorker.naturalTerminationBarrier) {
        while (!mainWorker.isFinished) {
          try {
            mainWorker.naturalTerminationBarrier.wait();
          } catch (InterruptedException e) {
            throw new RuntimeInterruptedException(e);
          }
        }
      }
      log.info("[Quiet training complete]");
      return mainWorker.weights;
    }
  }

  List<Constraint> constraints = new ArrayList<>();

  /**
   * This adds a constraint on the weight vector, that a certain component must be set to a sparse index=value
   *
   * @param component the component to fix
   * @param index     the index of the fixed sparse component
   * @param value     the value to fix at
   */
  public void addSparseConstraint(int component, int index, double value) {
    constraints.add(new Constraint(component, index, value));
  }

  /**
   * This adds a constraint on the weight vector, that a certain component must be set to a dense array
   *
   * @param component the component to fix
   * @param arr       the dense array to set
   */
  public void addDenseConstraint(int component, double[] arr) {
    constraints.add(new Constraint(component, arr));
  }
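
  // Illustrative usage sketch (an assumption, not code from this library): how a caller might pin parts of the
  // weight vector before training. "MyConcreteOptimizer", "data", and "myFunction" below are hypothetical
  // stand-ins for a concrete subclass of this class, a dataset array, and a matching AbstractDifferentiableFunction.
  //
  //   AbstractBatchOptimizer optimizer = new MyConcreteOptimizer();
  //   optimizer.addSparseConstraint(0, 3, 1.0);                     // hold component 0, sparse index 3, fixed at 1.0
  //   optimizer.addDenseConstraint(1, new double[]{0.0, 0.5});      // hold component 1 fixed at a dense array
  //   ConcatVector learned = optimizer.optimize(data, myFunction);  // constrained entries are never updated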

  /**
   * A way to record a constraint on the weight vector
   */
  private static class Constraint {
    int component;
    boolean isSparse;

    // Used if the constraint is sparse
    int index;
    double value;

    // Used if the constraint is dense
    double[] arr;

    public Constraint(int component, int index, double value) {
      isSparse = true;
      this.component = component;
      this.index = index;
      this.value = value;
    }

    public Constraint(int component, double[] arr) {
      isSparse = false;
      this.component = component;
      this.arr = arr;
    }

    public void applyToWeights(ConcatVector weights) {
      if (isSparse) {
        weights.setSparseComponent(component, index, value);
      } else {
        weights.setDenseComponent(component, arr);
      }
    }

    public void applyToDerivative(ConcatVector derivative) {
      if (isSparse) {
        derivative.setSparseComponent(component, index, 0.0);
      } else {
        // Replacing the whole component clears any gradient that accumulated on the constrained weights
        derivative.setDenseComponent(component, new double[]{0.0});
      }
    }
  }

  /**
   * This is the hook for subclasses of the batch optimizer to override in order to implement their update rule.
   *
   * @param weights       the current weights (update these in place)
   * @param gradient      the gradient at these weights
   * @param logLikelihood the log likelihood at these weights
   * @param state         any saved state the optimizer wants to keep and pass around during each optimization run
   * @param quiet         whether or not to dump output about progress to the console
   * @return whether or not we've converged
   */
  public abstract boolean updateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, OptimizationState state, boolean quiet);

  /**
   * This is subclassed by children to store any state they need to perform optimization
   */
  protected abstract class OptimizationState {
  }

  /**
   * This is called at the beginning of each batch optimization. It should return a fresh OptimizationState object that
   * will then be handed to updateWeights() on each update.
   *
   * @param initialWeights the initial weights for the optimizer to use
   * @return a fresh OptimizationState
   */
  protected abstract OptimizationState getFreshOptimizationState(ConcatVector initialWeights);
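
  // Illustrative sketch (an assumption, not code shipped with this class): a minimal subclass doing plain
  // gradient ascent with a fixed step size, just to show how the two abstract hooks fit together. The class name
  // "ConstantStepOptimizer" and the 0.1 step size are hypothetical.
  //
  //   public class ConstantStepOptimizer extends AbstractBatchOptimizer {
  //     @Override
  //     protected OptimizationState getFreshOptimizationState(ConcatVector initialWeights) {
  //       return new OptimizationState() {};  // this optimizer keeps no per-run state
  //     }
  //
  //     @Override
  //     public boolean updateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood,
  //                                  OptimizationState state, boolean quiet) {
  //       weights.addVectorInPlace(gradient, 0.1);  // step uphill on the log-likelihood
  //       return false;  // never claim convergence; rely on the derivative-norm check in TrainingWorker
  //     }
  //   }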

  private static class GradientWorker<T> implements Runnable {
    ConcatVector localDerivative;
    double localLogLikelihood = 0.0;

    TrainingWorker mainWorker;
    int threadIdx;
    int numThreads;
    List<T> queue;
    AbstractDifferentiableFunction<T> fn;
    ConcatVector weights;

    long jvmThreadId = 0;

    // This is to help the dynamic re-balancing of work queues
    long finishedAtTime = 0;
    long cpuTimeRequired = 0;

    public GradientWorker(TrainingWorker<T> mainWorker, int threadIdx, int numThreads, List<T> queue, AbstractDifferentiableFunction<T> fn, ConcatVector weights) {
      this.mainWorker = mainWorker;
      this.threadIdx = threadIdx;
      this.numThreads = numThreads;
      this.queue = queue;
      this.fn = fn;
      this.weights = weights;

      localDerivative = weights.newEmptyClone();
    }

    @Override
    public void run() {
      long startTime = ManagementFactory.getThreadMXBean().getThreadCpuTime(jvmThreadId);

      for (T datum : queue) {
        localLogLikelihood += fn.getSummaryForInstance(datum, weights, localDerivative);
        // Check for user interrupt
        if (mainWorker.isFinished) return;
      }

      finishedAtTime = System.currentTimeMillis();

      long endTime = ManagementFactory.getThreadMXBean().getThreadCpuTime(jvmThreadId);
      cpuTimeRequired = endTime - startTime;
    }
  }

  private class TrainingWorker<T> implements Runnable {
    ConcatVector weights;
    OptimizationState optimizationState;

    // volatile: this flag is written and read by several threads (the trainer, the gradient workers, and the
    // stdin-polling thread in optimize()), so updates must be visible across all of them
    volatile boolean isFinished = false;

    boolean useThreads = Runtime.getRuntime().availableProcessors() > 1;

    T[] dataset;
    AbstractDifferentiableFunction<T> fn;
    double l2regularization;
    double convergenceDerivativeNorm;
    boolean quiet;

    final Object naturalTerminationBarrier = new Object();

    public TrainingWorker(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, boolean quiet) {
      optimizationState = getFreshOptimizationState(initialWeights);
      weights = initialWeights.deepClone();

      this.dataset = dataset;
      this.fn = fn;
      this.l2regularization = l2regularization;
      this.convergenceDerivativeNorm = convergenceDerivativeNorm;
      this.quiet = quiet;
    }

    /**
     * This lets the system allocate work to threads evenly, which reduces the amount of blocking and can improve
     * runtimes by 20% or more.
     *
     * @param datum the datum to estimate work for
     * @return a work estimate, on a relative scale of single cpu wall time, for getting the gradient and log-likelihood
     */
    private int estimateRelativeRuntime(T datum) {
      if (datum instanceof GraphicalModel) {
        int cost = 0;
        GraphicalModel model = (GraphicalModel) datum;
        for (GraphicalModel.Factor f : model.factors) {
          cost += f.featuresTable.combinatorialNeighborStatesCount();
        }
        return cost;
      } else return 1;
    }

    @Override
    public void run() {
      // Multithreading stuff

      int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors());
      @SuppressWarnings("unchecked")
      List<T>[] queues = (List<T>[]) (new List[numThreads]);
      Random r = new Random();

      // Allocate work to make estimated cost of work per thread as even as possible: each datum goes to the
      // queue with the smallest estimated total cost so far
      if (useThreads) {
        for (int i = 0; i < numThreads; i++) {
          queues[i] = new ArrayList<>();
        }
        int[] queueEstimatedTotalCost = new int[numThreads];

        for (T datum : dataset) {
          int datumEstimatedCost = estimateRelativeRuntime(datum);

          int minCostQueue = 0;
          for (int i = 0; i < numThreads; i++) {
            if (queueEstimatedTotalCost[i] < queueEstimatedTotalCost[minCostQueue]) minCostQueue = i;
          }

          queueEstimatedTotalCost[minCostQueue] += datumEstimatedCost;
          queues[minCostQueue].add(datum);
        }
      }
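
      // Worked example of the greedy allocation above (illustrative numbers, not from the source): with 2 threads
      // and estimated costs arriving in the order 5, 4, 3, 2, each datum goes to the currently cheapest queue:
      //   cost 5 -> queue 0 (totals 5, 0)
      //   cost 4 -> queue 1 (totals 5, 4)
      //   cost 3 -> queue 1 (totals 5, 7)
      //   cost 2 -> queue 0 (totals 7, 7)
      // so both threads start the first pass with roughly equal estimated work.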

      while (!isFinished) {
        // Collect log-likelihood and derivatives

        long startTime = System.currentTimeMillis();
        long threadWaiting = 0;

        ConcatVector derivative = weights.newEmptyClone();
        double logLikelihood = 0.0;

        if (useThreads) {
          GradientWorker[] workers = new GradientWorker[numThreads];
          Thread[] threads = new Thread[numThreads];

          for (int i = 0; i < workers.length; i++) {
            workers[i] = new GradientWorker(this, i, numThreads, queues[i], fn, weights);
            threads[i] = new Thread(workers[i]);
            workers[i].jvmThreadId = threads[i].getId();
            threads[i].start();
          }

          // This is for logging
          long minFinishTime = Long.MAX_VALUE;
          long maxFinishTime = Long.MIN_VALUE;

          // This is for re-balancing
          long minCPUTime = Long.MAX_VALUE;
          long maxCPUTime = Long.MIN_VALUE;
          int slowestWorker = 0;
          int fastestWorker = 0;

          for (int i = 0; i < workers.length; i++) {
            try {
              threads[i].join();
            } catch (InterruptedException e) {
              throw new RuntimeInterruptedException(e);
            }
            logLikelihood += workers[i].localLogLikelihood;
            derivative.addVectorInPlace(workers[i].localDerivative, 1.0);

            if (workers[i].finishedAtTime < minFinishTime) {
              minFinishTime = workers[i].finishedAtTime;
            }
            if (workers[i].finishedAtTime > maxFinishTime) {
              maxFinishTime = workers[i].finishedAtTime;
            }
            if (workers[i].cpuTimeRequired < minCPUTime) {
              fastestWorker = i;
              minCPUTime = workers[i].cpuTimeRequired;
            }
            if (workers[i].cpuTimeRequired > maxCPUTime) {
              slowestWorker = i;
              maxCPUTime = workers[i].cpuTimeRequired;
            }
          }
          threadWaiting = maxFinishTime - minFinishTime;

          // Try to reallocate work dynamically to minimize waiting on subsequent rounds.
          // Figure out the percentage of work represented by the waiting, then move that share of the slowest
          // worker's queue (randomly chosen, and halved so the transfer closes the gap rather than overshooting it)
          // over to the fastest worker.
          double waitingPercentage = (double) (maxCPUTime - minCPUTime) / (double) maxCPUTime;

          int needTransferItems = (int) Math.floor(queues[slowestWorker].size() * waitingPercentage * 0.5);
          for (int i = 0; i < needTransferItems; i++) {
            int toTransfer = r.nextInt(queues[slowestWorker].size());
            T datum = queues[slowestWorker].get(toTransfer);
            queues[slowestWorker].remove(toTransfer);
            queues[fastestWorker].add(datum);
          }

          // Check for user interrupt
          if (isFinished) return;
        } else {
          for (T datum : dataset) {
            assert (datum != null);
            logLikelihood += fn.getSummaryForInstance(datum, weights, derivative);
            // Check for user interrupt
            if (isFinished) return;
          }
        }

        logLikelihood /= dataset.length;
        derivative.mapInPlace((d) -> d / dataset.length);

        long gradientComputationTime = System.currentTimeMillis() - startTime;

        // Regularization: subtract the L2 penalty from the objective and its gradient
        logLikelihood = logLikelihood - (l2regularization * weights.dotProduct(weights));
        derivative.addVectorInPlace(weights, -2 * l2regularization);

        // Zero out the derivative on the components we're holding fixed
        for (Constraint constraint : constraints) {
          constraint.applyToDerivative(derivative);
        }

        // If our derivative is sufficiently small, we've converged
        double derivativeNorm = derivative.dotProduct(derivative);
        if (derivativeNorm < convergenceDerivativeNorm) {
          if (!quiet)
            log.info("Derivative norm " + derivativeNorm + " < " + convergenceDerivativeNorm + ": quitting");
          break;
        }

        // Do the actual weight update
        if (!quiet) {
          log.info("[" + gradientComputationTime + " ms, threads waiting " + threadWaiting + " ms]");
        }
        boolean converged = updateWeights(weights, derivative, logLikelihood, optimizationState, quiet);

        // Apply constraints to the weights vector
        for (Constraint constraint : constraints) {
          constraint.applyToWeights(weights);
        }

        if (converged) {
          break;
        }
      }

      // Mark ourselves finished and wake the quiet-mode waiter; setting the flag inside the synchronized block
      // keeps the notifyAll() from being missed
      synchronized (naturalTerminationBarrier) {
        isFinished = true;
        naturalTerminationBarrier.notifyAll();
      }
    }
  }
}
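
// Illustrative end-to-end usage sketch (an assumption, not part of this file). "MyConcreteOptimizer",
// "trainingData", and "logLikelihoodFn" are hypothetical stand-ins for a concrete subclass, an array of training
// examples, and a matching AbstractDifferentiableFunction; the numbers are example hyperparameters only.
//
//   AbstractBatchOptimizer optimizer = new MyConcreteOptimizer();
//   ConcatVector initialWeights = new ConcatVector(0);
//   ConcatVector weights = optimizer.optimize(trainingData, logLikelihoodFn, initialWeights,
//       0.01,    // l2regularization
//       1.0e-5,  // convergenceDerivativeNorm
//       true);   // quiet: minimal logging and no stdin polling for early termination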