package de.jungblut.online.regularization; import org.apache.commons.math3.util.FastMath; import com.google.common.base.Preconditions; import de.jungblut.math.DoubleVector; import de.jungblut.math.minimize.CostGradientTuple; /** * Adam updater, inspired by nd4j. Whitepaper http://arxiv.org/abs/1412.6980 * */ public class AdamUpdater extends GradientDescentUpdater { public static final double MOVING_AVERAGE_DECAY = 0.9; public static final double SQUARED_DECAY = 0.999; public static final double EPS = 1e-8; private final double alpha; private final double movingAvgDecay; private final double squaredDecay; private final double eps; private DoubleVector movingAvg; private DoubleVector squaredGradient; public AdamUpdater(double alpha) { this(alpha, MOVING_AVERAGE_DECAY, SQUARED_DECAY); } public AdamUpdater(double alpha, double movingAvgDecay, double squaredDecay) { this(alpha, movingAvgDecay, squaredDecay, EPS); } public AdamUpdater(double alpha, double movingAvgDecay, double squaredDecay, double epsilon) { Preconditions.checkArgument(movingAvgDecay >= 0 && movingAvgDecay < 1, "movingAvgDecay must be [0, 1)!"); Preconditions.checkArgument(squaredDecay >= 0 && squaredDecay < 1, "squaredDecay must be [0, 1)!"); this.alpha = alpha; this.movingAvgDecay = movingAvgDecay; this.squaredDecay = squaredDecay; this.eps = epsilon; } @Override public CostGradientTuple updateGradient(DoubleVector theta, DoubleVector gradient, double learningRate, long iteration, double cost) { if (movingAvg == null) { // initialize same types with zeros movingAvg = gradient.deepCopy().multiply(0); squaredGradient = gradient.deepCopy().multiply(0); } DoubleVector oneMinusBeta1Grad = gradient.multiply(1d - movingAvgDecay); movingAvg = movingAvg.multiply(movingAvgDecay).add(oneMinusBeta1Grad); DoubleVector oneMinusBeta2GradSquared = gradient.pow(2d).multiply( 1 - squaredDecay); squaredGradient = squaredGradient.multiply(squaredDecay).add( oneMinusBeta2GradSquared); double beta1t = FastMath.pow(movingAvgDecay, iteration); double beta2t = FastMath.pow(squaredDecay, iteration); double alphat = alpha * FastMath.sqrt(1 - beta2t) / (1 - beta1t); if (Double.isNaN(alphat) || alphat == 0.0) { alphat = EPS; } DoubleVector sqrtV = squaredGradient.sqrt().add(eps); gradient = movingAvg.multiply(alphat).divide(sqrtV); return new CostGradientTuple(cost, gradient); } }