package edu.stanford.nlp.optimization;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
/**
* Stochastic Gradient Descent With AdaGrad and FOBOS in batch mode.
* Optionally, user can also turn on AdaDelta via option "useAdaDelta"
* Similar to SGDMinimizer, regularization is done in the minimizer, not in the objective function.
* This version is not efficient for the online setting. For an online variant, consider implementing SparseAdaGradMinimizer.java
*
* @author Mengqiu Wang
*/
public class SGDWithAdaGradAndFOBOS<T extends DiffFunction> implements Minimizer<T>, HasEvaluators {
/** A logger for this class */
private static Redwood.RedwoodChannels log = Redwood.channels(SGDWithAdaGradAndFOBOS.class);
// Current parameter vector; allocated and updated in place by minimize().
protected double[] x;
protected double initRate; // initial (base) learning rate used in the AdaGrad update
protected double lambda; // regularization strength for the FOBOS proximal step
// when alpha = 1, sg-lasso is just lasso; when alpha = 0, sg-lasso is g-lasso
protected double alpha = 1.0;
protected boolean quiet = false; // when true, say()/sayln() emit nothing
private static final int DEFAULT_NUM_PASSES = 50;
protected final int numPasses; //-1;
protected int bSize = 1; // NOTE: If bSize does not divide evenly into total number of samples,
// some samples may get accounted for twice in one pass
private static final int DEFAULT_TUNING_SAMPLES = Integer.MAX_VALUE;
private static final int DEFAULT_BATCH_SIZE = 1000;
private double eps = 1e-3; // smoothing term added inside the AdaGrad/AdaDelta square roots
private double TOL = 1e-4; // tolerance for the average-improvement termination check
// fields for approximating Hessian to feed to QN
public List<double[]> yList = null;
public List<double[]> sList = null;
public double[] diag;
private int hessSampleSize = -1; // <= 0 disables the (currently commented-out) Hessian approximation
private double[] s,y = null;
protected Random gen = new Random(1);
protected long maxTime = Long.MAX_VALUE; // wall-clock budget (same units as Timing.report()); minimize() stops when exceeded
private int evaluateIters = 0; // Evaluate every x iterations (0 = no evaluation)
private Evaluator[] evaluators; // separate set of evaluators to check how optimization is going
private Prior prior = Prior.LASSO;
private boolean useEvalImprovement = false; // terminate when the evaluation score stops improving
private boolean useAvgImprovement = false; // terminate on small average objective improvement
private boolean suppressTestPrompt = false; // when true, skip the per-evaluator log line
private int terminateOnEvalImprovementNumOfEpoch = 1; // patience: epochs without improvement before stopping
private double bestEvalSoFar = Double.NEGATIVE_INFINITY;
private double[] xBest; // snapshot of x at the best evaluation score seen so far
private int noImproveItrCount = 0;
private boolean useAdaDelta = false; // use AdaDelta learning rates instead of plain AdaGrad
private boolean useAdaDiff = false; // accumulate squared gradient *differences* rather than squared gradients
private double rho = 0.95; // AdaDelta exponential decay factor
private double[] sumGradSquare; // per-dimension accumulated squared gradients (EWMA under AdaDelta)
private double[] prevGrad, prevDeltaX; // previous gradient and previous parameter delta, per dimension
private double[] sumDeltaXSquare; // AdaDelta only: EWMA of squared parameter deltas
/**
 * Sets how many (s, y) pairs to keep for the Hessian approximation.
 *
 * @param hessSize sample size; values <= 0 leave the feature disabled
 */
public void setHessSampleSize(int hessSize) {
  // (TODO) should initialize relevant data structure as well
  this.hessSampleSize = hessSize;
}
/** Enables or disables early termination based on evaluation-score improvement. */
public void terminateOnEvalImprovement(boolean toTerminate) {
  this.useEvalImprovement = toTerminate;
}
/**
 * Enables or disables early termination based on average objective improvement.
 *
 * @param toTerminate whether to use the average-improvement stopping rule
 * @param tolerance   relative improvement threshold below which optimization stops
 */
public void terminateOnAvgImprovement(boolean toTerminate, double tolerance) {
  this.useAvgImprovement = toTerminate;
  this.TOL = tolerance;
}
/** When enabled, skips the per-evaluator log line during evaluation. */
public void suppressTestPrompt(boolean suppressTestPrompt) {
  this.suppressTestPrompt = suppressTestPrompt;
}
/** Sets the patience: how many epochs without eval improvement are tolerated. */
public void setTerminateOnEvalImprovementNumOfEpoch(int terminateOnEvalImprovementNumOfEpoch) {
  this.terminateOnEvalImprovementNumOfEpoch = terminateOnEvalImprovementNumOfEpoch;
}
/**
 * Records the latest evaluation score and decides whether to keep optimizing.
 * A score at least as good as the best so far counts as an improvement: it is
 * remembered along with a snapshot of {@code x}, and the patience counter is
 * reset. Otherwise the counter is incremented and optimization continues only
 * while it stays within {@code terminateOnEvalImprovementNumOfEpoch}.
 *
 * @param x        current parameter vector (copied into xBest on improvement)
 * @param currEval evaluation score for x
 * @return true to continue optimizing, false to stop
 */
public boolean toContinue(double[] x, double currEval) {
  if (currEval >= bestEvalSoFar) {
    // Improvement (ties included): snapshot the parameters and reset patience.
    bestEvalSoFar = currEval;
    noImproveItrCount = 0;
    if (xBest == null) {
      xBest = Arrays.copyOf(x, x.length);
    } else {
      System.arraycopy(x, 0, xBest, 0, x.length);
    }
    return true;
  }
  noImproveItrCount += 1;
  return noImproveItrCount <= terminateOnEvalImprovementNumOfEpoch;
}
/**
 * Regularization priors applied through FOBOS proximal steps in minimize():
 * LASSO (L1 soft-thresholding), RIDGE (norm-based shrinkage), GAUSSIAN
 * (L2 handled inside the objective itself), aeLASSO / gLASSO (adaptive
 * elementwise / group lasso over feature groups), sgLASSO (sparse group
 * lasso mixing L1 and group penalties via alpha), and NONE.
 */
public enum Prior {
LASSO, RIDGE, GAUSSIAN, aeLASSO, gLASSO, sgLASSO, NONE
}
/**
 * Maps a prior name (as used on the command line / in properties) to its enum.
 *
 * @param priorType one of "none", "lasso", "ridge", "gaussian", "ae-lasso",
 *                  "g-lasso", "sg-lasso"
 * @throws IllegalArgumentException for any unrecognized name
 */
private static Prior getPrior(String priorType) {
  // Note: priorType.equals(...) preserves the original switch's NPE on null input.
  if (priorType.equals("none")) {
    return Prior.NONE;
  } else if (priorType.equals("lasso")) {
    return Prior.LASSO;
  } else if (priorType.equals("ridge")) {
    return Prior.RIDGE;
  } else if (priorType.equals("gaussian")) {
    return Prior.GAUSSIAN;
  } else if (priorType.equals("ae-lasso")) {
    return Prior.aeLASSO;
  } else if (priorType.equals("g-lasso")) {
    return Prior.gLASSO;
  } else if (priorType.equals("sg-lasso")) {
    return Prior.sgLASSO;
  }
  throw new IllegalArgumentException("prior type " + priorType + " not recognized; supported priors " +
      "are: lasso, ridge, gaussian, ae-lasso, g-lasso, and sg-lasso");
}
/** Convenience constructor: lasso prior, batch size chosen automatically. */
public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses) {
  this(initRate, lambda, numPasses, -1);
}
/** Convenience constructor: lasso prior, plain AdaGrad with default eps/rho. */
public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int batchSize) {
  this(initRate, lambda, numPasses, batchSize, "lasso", 1.0, false, false, 1e-3, 0.95);
}
/**
 * Full constructor.
 *
 * @param initRate    base learning rate for AdaGrad
 * @param lambda      regularization strength
 * @param numPasses   passes over the data; negative values fall back to DEFAULT_NUM_PASSES
 * @param batchSize   samples per batch (sanity-checked again in minimize())
 * @param priorType   prior name, see {@link #getPrior(String)}
 * @param alpha       lasso/group-lasso mixing weight for sg-lasso
 * @param useAdaDelta use AdaDelta learning rates instead of AdaGrad
 * @param useAdaDiff  accumulate squared gradient differences instead of squared gradients
 * @param adaGradEps  smoothing constant inside the square roots
 * @param adaDeltaRho AdaDelta decay factor
 * @throws IllegalArgumentException if priorType is not recognized
 */
public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses,
    int batchSize, String priorType, double alpha, boolean useAdaDelta, boolean useAdaDiff, double adaGradEps, double adaDeltaRho)
{
  this.initRate = initRate;
  this.lambda = lambda;
  this.alpha = alpha;
  this.bSize = batchSize;
  this.prior = getPrior(priorType);
  this.eps = adaGradEps;
  this.rho = adaDeltaRho;
  this.useAdaDelta = useAdaDelta;
  this.useAdaDiff = useAdaDiff;
  if (numPasses < 0) {
    // Negative means "not specified": use the default and say so.
    this.numPasses = DEFAULT_NUM_PASSES;
    sayln(" SGDWithAdaGradAndFOBOS: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
  } else {
    this.numPasses = numPasses;
  }
}
/** Silences all logging from this minimizer. */
public void shutUp() {
  this.quiet = true;
}
// Scientific-notation formatter (3 decimals) for lambda/alpha/objective values in logs.
private static final NumberFormat nf = new DecimalFormat("0.000E0");
// Descriptive name encoding batch size, lambda and alpha (batch size is appended
// directly after the class name with no separator).
protected String getName() {
return "SGDWithAdaGradAndFOBOS" + bSize + "_lambda" + nf.format(lambda) + "_alpha" + nf.format(alpha);
}
/**
 * Registers evaluators to be run periodically during optimization.
 *
 * @param iters      run the evaluators every this many passes (0 disables evaluation)
 * @param evaluators evaluators applied to the current parameter vector
 */
@Override
public void setEvaluators(int iters, Evaluator[] evaluators) {
  this.evaluateIters = iters;
  this.evaluators = evaluators;
}
/** Returns the Euclidean (L2) norm of {@code w}. */
private static double getNorm(double[] w) {
  double sumOfSquares = 0;
  for (int i = 0; i < w.length; i++) {
    sumOfSquares += w[i] * w[i];
  }
  return Math.sqrt(sumOfSquares);
}
/**
 * Runs every registered evaluator on {@code x}.
 * Returns the score of the last evaluator that produced one, or
 * Double.NEGATIVE_INFINITY when there are no evaluators or no scores.
 */
private double doEvaluation(double[] x) {
  if (evaluators == null) {
    return Double.NEGATIVE_INFINITY;
  }
  double score = Double.NEGATIVE_INFINITY;
  for (Evaluator eval : evaluators) {
    if (!suppressTestPrompt) {
      sayln(" Evaluating: " + eval.toString());
    }
    double aScore = eval.evaluate(x);
    // Later evaluators overwrite earlier scores; NEGATIVE_INFINITY results are ignored.
    if (aScore != Double.NEGATIVE_INFINITY) {
      score = aScore;
    }
  }
  return score;
}
/** Positive part: the number itself if strictly positive, otherwise 0.0 (including for NaN). */
private static double pospart(double number) {
  if (number > 0.0) {
    return number;
  }
  return 0.0;
}
/*
private void approxHessian(double[] newX) {
for(int i = 0; i < x.length; i++){
double thisGain = fixedGain*gainSchedule(k,5*numBatches)/(diag[i]);
newX[i] = x[i] - thisGain*grad[i];
}
//Get a new pair...
say(" A ");
if (hessSampleSize > 0 && sList.size() == hessSampleSize || sList.size() == hessSampleSize) {
s = sList.remove(0);
y = yList.remove(0);
} else {
s = new double[x.length];
y = new double[x.length];
}
s = prevDeltaX;
s = ArrayMath.pairwiseSubtract(newX, x);
dfunction.recalculatePrevBatch = true;
System.arraycopy(dfunction.derivativeAt(newX,bSize),0,y,0,grad.length);
ArrayMath.pairwiseSubtractInPlace(y,newGrad); // newY = newY-newGrad
double[] comp = new double[x.length];
sList.add(s);
yList.add(y);
ScaledSGDMinimizer.updateDiagBFGS(diag,s,y);
}
*/
/**
 * Computes the adaptive learning rate for one coordinate, updating the running
 * accumulators as a side effect.
 *
 * <p>The squared quantity accumulated is either the raw gradient component or,
 * with useAdaDiff, the change in that component since the previous batch.
 * Under AdaDelta both accumulators are exponential moving averages (decay rho)
 * and the rate is RMS(delta x) / RMS(grad); under plain AdaGrad the squared
 * gradients are summed and the rate is initRate / sqrt(sum + eps).
 *
 * @param index coordinate being updated
 * @param grad  current gradient component for that coordinate
 * @return the per-coordinate learning rate
 */
private double computeLearningRate(int index, double grad) {
  double gradDiff = grad - prevGrad[index];
  double g = useAdaDiff ? gradDiff : grad;
  if (useAdaDelta) {
    // EWMA of squared parameter deltas (uses the delta applied in the previous step).
    double deltaXt = prevDeltaX[index];
    sumDeltaXSquare[index] = sumDeltaXSquare[index] * rho + (1 - rho) * deltaXt * deltaXt;
    // EWMA of squared gradients (or gradient differences).
    sumGradSquare[index] = sumGradSquare[index] * rho + (1 - rho) * g * g;
    return Math.sqrt(sumDeltaXSquare[index] + eps) / Math.sqrt(sumGradSquare[index] + eps);
  }
  // Plain AdaGrad: monotone accumulation of squared gradients.
  sumGradSquare[index] += g * g;
  return initRate / Math.sqrt(sumGradSquare[index] + eps);
}
/**
 * Writes the new value into x[index], first recording the applied delta
 * (needed by AdaDelta's RMS(delta x) accumulator). The delta must be computed
 * before the write, so the statement order here matters.
 */
private void updateX(double[] x, int index, double realUpdate) {
  prevDeltaX[index] = realUpdate - x[index];
  x[index] = realUpdate;
}
/** Minimizes with no explicit iteration cap; the cap is derived from numPasses. */
@Override
public double[] minimize(DiffFunction function, double functionTolerance, double[] initial) {
  return this.minimize(function, functionTolerance, initial, -1);
}
/**
 * Minimizes {@code f} with batched AdaGrad (or AdaDelta) learning rates and a
 * FOBOS proximal step for the configured prior.
 *
 * @param f function to minimize; stochastic batching is only used when it is an
 *          AbstractStochasticCachingDiffUpdateFunction
 * @param functionTolerance unused by this implementation
 * @param initial starting parameter vector (copied, not mutated)
 * @param maxIterations cap on batch iterations; effectively raised to
 *                      numPasses * numBatches below
 * @return the final parameter vector, or the best-scoring snapshot when
 *         terminateOnEvalImprovement is enabled
 */
@Override
public double[] minimize(DiffFunction f, double functionTolerance, double[] initial, int maxIterations) {
int totalSamples = 0;
sayln("Using lambda=" + lambda);
// For stochastic functions, use shuffled sampling and sanity-check the batch size.
if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
AbstractStochasticCachingDiffUpdateFunction func = (AbstractStochasticCachingDiffUpdateFunction) f;
func.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Shuffled;
totalSamples = func.dataDimension();
if (bSize > totalSamples) {
log.info("WARNING: Total number of samples=" + totalSamples +
" is smaller than requested batch size=" + bSize + "!!!");
bSize = totalSamples;
sayln("Using batch size=" + bSize);
}
if (bSize <= 0) {
log.info("WARNING: Requested batch size=" + bSize + " <= 0 !!!");
bSize = totalSamples;
sayln("Using batch size=" + bSize);
}
}
x = new double[initial.length];
double[] testUpdateCache = null, currentRateCache = null, bCache = null;
// Per-dimension accumulators for the adaptive learning rates.
sumGradSquare = new double[initial.length];
prevGrad = new double[initial.length];
prevDeltaX = new double[initial.length];
if (useAdaDelta) {
sumDeltaXSquare = new double[initial.length];
if (prior != Prior.NONE && prior != Prior.GAUSSIAN) {
throw new UnsupportedOperationException("useAdaDelta is currently only supported for Prior.NONE or Prior.GAUSSIAN");
}
}
int[][] featureGrouping = null;
// Caches of x(t+1/2) and per-index rates, needed by priors whose proximal step
// couples multiple coordinates (RIDGE and the group lassos).
if (prior != Prior.LASSO && prior != Prior.NONE) {
testUpdateCache = new double[initial.length];
currentRateCache = new double[initial.length];
}
// NOTE(review): Prior.NONE also satisfies this condition, so a NONE prior on a
// function without feature grouping throws here — confirm whether NONE should
// be excluded from this check.
if (prior != Prior.LASSO && prior != Prior.RIDGE && prior != Prior.GAUSSIAN) {
if (!(f instanceof HasFeatureGrouping)) {
throw new UnsupportedOperationException("prior is specified to be ae-lasso or g-lasso, but function does not support feature grouping");
}
featureGrouping = ((HasFeatureGrouping)f).getFeatureGrouping();
}
if (prior == Prior.sgLASSO) {
bCache = new double[initial.length];
}
System.arraycopy(initial, 0, x, 0, x.length);
int numBatches = 1;
if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
if (totalSamples > 0)
numBatches = totalSamples / bSize;
}
boolean have_max = (maxIterations > 0 || numPasses > 0);
if (!have_max){
throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
} else{
maxIterations = Math.max(maxIterations, numPasses*numBatches);
}
sayln(" Batch size of: " + bSize);
sayln(" Data dimension of: " + totalSamples );
sayln(" Batches per pass through data: " + numBatches );
sayln(" Number of passes is = " + numPasses);
sayln(" Max iterations is = " + maxIterations);
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Main optimization loop: one outer iteration per pass over the data.
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Timing total = new Timing();
Timing current = new Timing();
total.start();
current.start();
int iters = 0;
double gValue = 0;
double wValue = 0;
double currentRate = 0, testUpdate = 0, realUpdate = 0;
List<Double> values = null;
double oldObjVal = 0;
for (int pass = 0; pass < numPasses; pass++) {
// Periodic evaluation: every evaluateIters *passes* (despite the field comment
// saying iterations); may terminate early on no eval improvement.
boolean doEval = (pass > 0 && evaluateIters > 0 && pass % evaluateIters == 0);
double evalScore = Double.NEGATIVE_INFINITY;
if (doEval) {
evalScore = doEvaluation(x);
if (useEvalImprovement && !toContinue(x, evalScore))
break;
}
// TODO: currently objVal is only updated for GAUSSIAN prior
// when other priors are used, objVal only reflects the un-regularized obj value
double objVal = Double.NEGATIVE_INFINITY;
double objDelta = Double.NEGATIVE_INFINITY;
say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
int numOfNonZero = 0, numOfNonZeroGroup = 0;
String gSizeStr = "";
for (int batch = 0; batch < numBatches; batch++) {
iters++;
//Get the next gradients
// log.info("getting gradients");
// NOTE(review): gradients stays null when f is neither an
// AbstractStochasticCachingDiffUpdateFunction nor an AbstractCachingDiffFunction,
// which would NPE below — confirm callers only pass those types.
double[] gradients = null;
if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
AbstractStochasticCachingDiffUpdateFunction func = (AbstractStochasticCachingDiffUpdateFunction) f;
if (bSize == totalSamples) {
// Full-batch case: also track the true objective value for convergence checks.
objVal = func.valueAt(x);
gradients = func.getDerivative();
objDelta = objVal-oldObjVal;
oldObjVal = objVal;
if (values == null)
values = new ArrayList<>();
values.add(objVal);
} else {
func.calculateStochasticGradient(x, bSize);
gradients = func.getDerivative();
}
} else if (f instanceof AbstractCachingDiffFunction) {
AbstractCachingDiffFunction func = (AbstractCachingDiffFunction) f;
gradients = func.derivativeAt(x);
}
// log.info("applying regularization");
if (prior == Prior.NONE || prior == Prior.GAUSSIAN) { // Gaussian prior is also handled in objective
// Plain adaptive-rate SGD step; no proximal shrinkage applied here.
for (int index = 0; index < x.length; index++) {
gValue = gradients[index];
currentRate = computeLearningRate(index, gValue);
// arrive at x(t+1/2)
wValue = x[index];
testUpdate = wValue - (currentRate * gValue);
realUpdate = testUpdate;
updateX(x, index, realUpdate);
// x[index] = testUpdate;
}
} else if (prior == Prior.LASSO || prior == Prior.RIDGE) {
double testUpdateSquaredSum = 0;
// Only regularize indices in the function-provided range (default: all indices).
Set<Integer> paramRange = null;
if (f instanceof HasRegularizerParamRange) {
paramRange = ((HasRegularizerParamRange)f).getRegularizerParamRange(x);
} else {
paramRange = new HashSet<>();
for (int i = 0; i < x.length; i++)
paramRange.add(i);
}
for (int index : paramRange) {
gValue = gradients[index];
currentRate = computeLearningRate(index, gValue);
// arrive at x(t+1/2)
wValue = x[index];
testUpdate = wValue - (currentRate * gValue);
double currentLambda = currentRate * lambda;
// apply FOBOS
if (prior == Prior.LASSO) {
// L1 proximal step: soft-thresholding at currentLambda.
realUpdate = Math.signum(testUpdate) * pospart(Math.abs(testUpdate) - currentLambda);
updateX(x, index, realUpdate);
if (realUpdate != 0)
numOfNonZero++;
} else if (prior == Prior.RIDGE) {
// Defer the RIDGE update: it needs the norm over all cached half-steps.
testUpdateSquaredSum += testUpdate*testUpdate;
testUpdateCache[index] = testUpdate;
currentRateCache[index] = currentRate;
// } else if (prior == Prior.GAUSSIAN) { // GAUSSIAN prior is assumed to be handled in the objective directly
// realUpdate = testUpdate / (1 + currentLambda);
// updateX(x, index, realUpdate);
// // update objVal
// objVal += currentLambda * wValue * wValue;
}
}
if (prior == Prior.RIDGE) {
// NOTE(review): this loop runs over ALL indices, but the caches were only
// filled for indices in paramRange; with a restricted range the untouched
// zero entries would zero out the other parameters — confirm intended.
double testUpdateNorm = Math.sqrt(testUpdateSquaredSum);
for (int index = 0 ; index < testUpdateCache.length; index++) {
realUpdate = testUpdateCache[index] * pospart( 1 - currentRateCache[index] * lambda / testUpdateNorm );
updateX(x, index, realUpdate);
if (realUpdate != 0)
numOfNonZero++;
}
}
} else {
// Group priors (g-lasso, ae-lasso, sg-lasso): process one feature group at a time.
// log.info("featureGroup.length: " + featureGrouping.length);
for (int[] gFeatureIndices : featureGrouping) {
// if (gIndex % 100 == 0) log.info(gIndex+" ");
double testUpdateSquaredSum = 0;
double testUpdateAbsSum = 0;
double M = gFeatureIndices.length;
double dm = Math.log(M); // group-size weighting used by g-lasso and sg-lasso
// First pass over the group: unregularized half-step x(t+1/2) per coordinate,
// accumulating the group's L2 and L1 mass.
for (int index : gFeatureIndices) {
gValue = gradients[index];
currentRate = computeLearningRate(index, gValue);
// arrive at x(t+1/2)
wValue = x[index];
testUpdate = wValue - (currentRate * gValue);
testUpdateSquaredSum += testUpdate * testUpdate;
testUpdateAbsSum += Math.abs(testUpdate);
testUpdateCache[index] = testUpdate;
currentRateCache[index] = currentRate;
}
if (prior == Prior.gLASSO) {
// Group lasso: shrink the whole group proportionally to its L2 norm.
double testUpdateNorm = Math.sqrt(testUpdateSquaredSum);
boolean groupHasNonZero = false;
for (int index : gFeatureIndices) {
realUpdate = testUpdateCache[index] * pospart(1 - currentRateCache[index] * lambda * dm / testUpdateNorm);
updateX(x, index, realUpdate);
if (realUpdate != 0) {
numOfNonZero++;
groupHasNonZero = true;
}
}
if (groupHasNonZero)
numOfNonZeroGroup++;
} else if (prior == Prior.aeLASSO) {
// Adaptive elementwise lasso: per-coordinate soft-threshold with a
// group-dependent tau built from the group's L1 mass.
int nonZeroCount = 0;
boolean groupHasNonZero = false;
for (int index : gFeatureIndices) {
double tau = currentRateCache[index] * lambda / (1 + currentRateCache[index] * lambda * M) * testUpdateAbsSum;
realUpdate = Math.signum(testUpdateCache[index]) * pospart(Math.abs(testUpdateCache[index]) - tau);
updateX(x, index, realUpdate);
if (realUpdate != 0) {
numOfNonZero++;
nonZeroCount++;
groupHasNonZero = true;
}
}
if (groupHasNonZero)
numOfNonZeroGroup++;
// gSizeStr += nonZeroCount+",";
} else if (prior == Prior.sgLASSO) {
// Sparse group lasso: an alpha-weighted L1 soft-threshold per coordinate,
// followed by a group-level shrink of the thresholded vector b.
double bSquaredSum = 0, b = 0;
for (int index : gFeatureIndices) {
b = Math.signum(testUpdateCache[index]) * pospart(Math.abs(testUpdateCache[index]) -
currentRateCache[index] * alpha * lambda);
bCache[index] = b;
bSquaredSum += b * b;
}
double bNorm = Math.sqrt(bSquaredSum);
int nonZeroCount = 0;
boolean groupHasNonZero = false;
for (int index : gFeatureIndices) {
realUpdate = bCache[index] * pospart(1 - currentRateCache[index] * (1.0 - alpha) * lambda * dm / bNorm);
updateX(x, index, realUpdate);
if (realUpdate != 0) {
numOfNonZero++;
nonZeroCount++;
groupHasNonZero = true;
}
}
if (groupHasNonZero) {
numOfNonZeroGroup++;
// gSizeStr += nonZeroCount+",";
}
}
}
// log.info();
}
// update gradient and lastX
for (int index = 0; index < x.length; index++) {
prevGrad[index] = gradients[index];
}
// if (hessSampleSize > 0) {
// approxHessian();
// }
}
// If any parameter became non-finite this pass, flood x with NaN and stop.
try {
ArrayMath.assertFinite(x,"x");
} catch (ArrayMath.InvalidElementException e) {
log.info(e.toString());
for(int i=0;i<x.length;i++){ x[i]=Double.NaN; }
break;
}
sayln(String.valueOf(numBatches)+", n0-fCount:" + numOfNonZero + ((prior != Prior.LASSO && prior != Prior.RIDGE)? ", n0-gCount:"+numOfNonZeroGroup : "") + ((evalScore != Double.NEGATIVE_INFINITY) ? ", evalScore:"+evalScore : "") + (objVal != Double.NEGATIVE_INFINITY ? ", obj_val:" + nf.format(objVal) + ", obj_delta:" + objDelta : "") );
// Average-improvement stopping rule. values is only populated in the
// full-batch case (bSize == totalSamples), so this is inert otherwise.
if (values != null && useAvgImprovement && iters > 5) {
int size = values.size();
double previousVal = (size >= 10 ? values.get(size - 10) : values.get(0));
double averageImprovement = (previousVal - objVal) / (size >= 10 ? 10 : size);
if (Math.abs(averageImprovement / objVal) < TOL) {
sayln("Online Optmization completed, due to average improvement: | newest_val - previous_val | / |newestVal| < TOL ");
break;
}
}
if (iters >= maxIterations) {
sayln("Online Optimization complete. Stopped after max iterations");
break;
}
if (total.report() >= maxTime){
sayln("Online Optimization complete. Stopped after max time");
break;
}
}
if (evaluateIters > 0) {
// do final evaluation
double evalScore = (useEvalImprovement ? doEvaluation(xBest) : doEvaluation(x));
sayln("final evalScore is: " + evalScore);
}
sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");
return (useEvalImprovement ? xBest : x);
}
/** Logs the message unless this minimizer has been silenced via shutUp(). */
protected void sayln(String s) {
  if (quiet) {
    return;
  }
  log.info(s);
}
/** Logs the message unless silenced. Identical to sayln(); Redwood appends its own newline. */
protected void say(String s) {
  if (quiet) {
    return;
  }
  log.info(s);
}
}