package edu.berkeley.nlp.math;

import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.util.CallbackFunction;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Option;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * Stochastic (sub)gradient-descent minimizer for an objective expressed as a sum of
 * per-item terms plus an optional {@link Regularizer}.
 *
 * <p>Each pass maps the item collection over one or more
 * {@link ObjectiveItemDifferentiableFunction} workers (run concurrently via
 * {@link AsynchronousMapper}); every worker snapshots the shared weight vector,
 * computes a per-item gradient, and applies an update of size {@code alpha}
 * under {@code weightLock}. The step size is multiplied by {@code upAlphaMult}
 * after an improving pass and by {@code downAlphaMult} otherwise; optimization
 * stops after {@code numIters} passes or once {@code alpha} falls below 1% of
 * its initial value.
 *
 * <p>When {@code doAveraging} is set, the returned solution is the average of
 * all post-update weight vectors (Polyak-style averaging) rather than the final one.
 *
 * User: aria42
 * Date: Mar 10, 2009
 */
public class StochasticObjectiveOptimizer<I> {

	// --- per-call state, set by minimize(...) ---
	Collection<I> items;
	List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns;
	Regularizer regularizer; // may be null (unregularized objective)

	// --- step-size schedule ---
	double initAlpha = 0.5;
	double upAlphaMult = 1.1;   // applied when a pass lowers the objective
	double downAlphaMult = 0.5; // applied when a pass fails to improve

	// Guards reads/writes of weights, sumWeightVector, and numUpdates across mapper threads.
	Object weightLock = new Object();
	double[] weights;
	double alpha;

	CallbackFunction iterDoneCallback; // invoked with (iter, weights, val, alpha) after each pass
	boolean printProgress = true;
	Random rand;

	@Option
	public int randSeed = 0;
	@Option
	public boolean doAveraging = false;
	@Option
	public boolean shuffleData = false;

	// Running sum of post-update weight vectors, used for averaging.
	double[] sumWeightVector;
	int numUpdates;

	public StochasticObjectiveOptimizer(double initAlpha, double upAlphaMult, double downAlphaMult) {
		this(initAlpha, upAlphaMult, downAlphaMult, true);
	}

	/**
	 * @param initAlpha     initial step size
	 * @param upAlphaMult   step-size multiplier after an improving pass
	 * @param downAlphaMult step-size multiplier after a non-improving pass
	 * @param printProgress whether to log per-iteration objective and step size
	 */
	public StochasticObjectiveOptimizer(double initAlpha, double upAlphaMult, double downAlphaMult,
			boolean printProgress) {
		this.initAlpha = initAlpha;
		this.upAlphaMult = upAlphaMult;
		this.downAlphaMult = downAlphaMult;
		this.printProgress = printProgress;
		// NOTE(review): randSeed is an @Option but option filling is disabled, so the
		// seed is always the field default (0) at construction time.
		rand = new Random(randSeed);
	}

	public void setIterationCallback(CallbackFunction iterDoneCallback) {
		this.iterDoneCallback = iterDoneCallback;
	}

	/**
	 * One SGD worker: for each item, snapshots the current weights, computes the
	 * per-item gradient (plus a 1/|items| share of the regularizer gradient), and
	 * applies the step under weightLock. Also accumulates the per-item objective
	 * value in {@code val} (thread-confined: one mapper instance per worker).
	 */
	class GradMapper implements SimpleMapper<I> {
		double val = 0.0;
		ObjectiveItemDifferentiableFunction<I> itemFn;

		GradMapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
			this.itemFn = itemFn;
		}

		public void map(I elem) {
			double[] localWeights;
			synchronized (weightLock) {
				// Copy so gradient computation can run outside the lock.
				localWeights = DoubleArrays.clone(weights);
			}
			double[] localGrad = new double[dimension()];
			itemFn.setWeights(localWeights);
			val += itemFn.update(elem, localGrad);
			if (regularizer != null) {
				val += regularizer.update(localWeights, localGrad, 1.0 / items.size());
			}
			synchronized (weightLock) {
				// Descend: weights -= alpha * grad, then fold into the running average.
				DoubleArrays.addInPlace(weights, localGrad, -alpha);
				DoubleArrays.addInPlace(sumWeightVector, weights);
				numUpdates++;
			}
		}
	}

	/**
	 * Objective-value-only worker for evaluating fixed parameters (no gradient, no
	 * update). Its driver pass in doIter() is currently disabled; retained for
	 * callers that want an exact objective at fixed weights.
	 */
	class ValMapper implements SimpleMapper<I> {
		double val = 0.0;
		ObjectiveItemDifferentiableFunction<I> itemFn;

		ValMapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
			this.itemFn = itemFn;
		}

		public void map(I elem) {
			val += itemFn.update(elem, null);
			// BUG FIX: GradMapper treats a null regularizer as "no regularization";
			// this path previously dereferenced it unconditionally and would NPE.
			if (regularizer != null) {
				val += regularizer.val(weights, 1.0 / items.size());
			}
		}
	}

	/**
	 * Runs one full pass of SGD over the (optionally shuffled) items and returns
	 * the sum of per-item objective values observed during the pass. Note this is
	 * an online estimate — weights change mid-pass — not the exact objective at
	 * the final weights.
	 */
	private double doIter() {
		List<GradMapper> gradMappers = new ArrayList<GradMapper>();
		for (ObjectiveItemDifferentiableFunction<I> itemFn : itemFns) {
			gradMappers.add(new GradMapper(itemFn));
		}
		List<I> shuffledItems = shuffleData ? CollectionUtils.shuffle(items, rand) : new ArrayList<I>(items);
		AsynchronousMapper.doMapping(shuffledItems, gradMappers);
		double val = 0.0;
		for (GradMapper mapper : gradMappers) {
			val += mapper.val;
		}
		return val;
	}

	/**
	 * Minimizes the summed objective by stochastic gradient descent.
	 *
	 * @param initWeights starting point (copied; not mutated)
	 * @param numIters    maximum number of passes over the data
	 * @param items       training items
	 * @param itemFns     one differentiable per-item function per worker thread
	 * @param regularizer optional regularizer; may be null
	 * @return the final weight vector, or the averaged weight vector when
	 *         {@code doAveraging} is set
	 */
	public double[] minimize(double[] initWeights, int numIters, Collection<I> items,
			List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns, Regularizer regularizer) {
		this.items = items;
		this.itemFns = itemFns;
		this.numUpdates = 0;
		this.regularizer = regularizer;
		alpha = initAlpha;
		weights = DoubleArrays.clone(initWeights);
		sumWeightVector = DoubleArrays.constantArray(0.0, weights.length);
		// Single tolerance local so the tested value and the logged value cannot diverge.
		double tolerance = initAlpha * Math.pow(10.0, -2.0);
		double lastVal = Double.POSITIVE_INFINITY;
		for (int iter = 0; iter < numIters; iter++) {
			double val = doIter();
			double alphaMult = val < lastVal ? upAlphaMult : downAlphaMult;
			alpha *= alphaMult;
			lastVal = val;
			if (printProgress) {
				Logger.logs("[StochasticObjectiveOptimizer] Ended Iteration %d with value %.5f", iter + 1, val);
				Logger.logs("[StochasticObjectiveOptimizer] New Alpha: %.5f (scaled by %.5f)", alpha, alphaMult);
			}
			if (iterDoneCallback != null) {
				iterDoneCallback.callback(iter, doAveraging ? avgWeightVector() : weights, val, alpha);
			}
			if (alpha < tolerance) {
				Logger.logs("[StochasticObjectiveOptimizer] alpha %.5f below tolerance %.5f, saying converged",
						alpha, tolerance);
				break;
			}
		}
		return doAveraging ? avgWeightVector() : weights;
	}

	/** Average of all post-update weight vectors seen so far. */
	private double[] avgWeightVector() {
		if (numUpdates == 0) {
			// No updates yet (e.g. empty item collection): the average is undefined;
			// previously this divided by zero and returned a NaN-filled vector.
			return DoubleArrays.clone(weights);
		}
		double[] avgWeights = DoubleArrays.clone(sumWeightVector);
		DoubleArrays.scale(avgWeights, 1.0 / numUpdates);
		return avgWeights;
	}

	/** Dimension of the weight vector, as reported by the first item function. */
	public int dimension() {
		return itemFns.get(0).dimension();
	}
}