package de.jungblut.online.regularization;
import java.util.Iterator;
import org.apache.commons.math3.util.FastMath;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.online.ml.FeatureOutcomePair;
/**
* Based on the paper:
* http://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf
*
* @author thomas.jungblut
*
*/
public final class AdaptiveFTRLRegularizer implements WeightUpdater {

  private final double beta;
  private final double l1;
  private final double l2;

  // Lazily initialized to zero vectors of the same type as theta on first use.
  private DoubleVector squaredPreviousGradient; // n in the paper
  private DoubleVector perCoordinateWeights; // z in the paper

  /**
   * Creates a new AdaptiveFTRLRegularizer.
   *
   * @param beta the smoothing parameter for the per-coordinate learning rate.
   * @param l1 the l1 regularization (sparsity-inducing threshold on |z_i|).
   * @param l2 the l2 regularization.
   */
  public AdaptiveFTRLRegularizer(double beta, double l1, double l2) {
    this.beta = beta;
    this.l1 = l1;
    this.l2 = l2;
  }

  /**
   * Computes the closed-form FTRL-Proximal weights for the non-zero
   * coordinates of the given feature before prediction:
   *
   * w_i = 0                                               if |z_i| &lt;= l1
   * w_i = -(z_i - sgn(z_i) * l1) / ((beta + sqrt(n_i)) / alpha + l2)  otherwise
   *
   * @param featureOutcome the current example; only its feature's non-zero
   *          coordinates are updated in theta.
   * @param theta the weight vector, updated in place.
   * @param learningRate alpha in the paper.
   * @param iteration the current iteration (unused here).
   * @return the mutated theta.
   */
  @Override
  public DoubleVector prePredictionWeightUpdate(
      FeatureOutcomePair featureOutcome, DoubleVector theta,
      double learningRate, long iteration) {
    ensureStateInitialized(theta);
    Iterator<DoubleVectorElement> iterateNonZero = featureOutcome.getFeature()
        .iterateNonZero();
    while (iterateNonZero.hasNext()) {
      DoubleVectorElement next = iterateNonZero.next();
      int index = next.getIndex();
      double zi = perCoordinateWeights.get(index);
      double ni = squaredPreviousGradient.get(index);
      if (FastMath.abs(zi) <= l1) {
        // l1 shrinks small accumulated values to exactly zero (sparsity)
        theta.set(index, 0);
      } else {
        // BUGFIX: the paper's formula subtracts sgn(z_i) * l1 — the sign of
        // the accumulated z value, not the sign of the feature value.
        double value = -1d / (((beta + FastMath.sqrt(ni)) / learningRate) + l2);
        value = value * (zi - FastMath.signum(zi) * l1);
        theta.set(index, value);
      }
    }
    return theta;
  }

  /**
   * Updates the per-coordinate accumulators z (perCoordinateWeights) and n
   * (squaredPreviousGradient) from the observed gradient:
   *
   * sigma_i = (sqrt(n_i + g_i^2) - sqrt(n_i)) / alpha
   * z_i    += g_i - sigma_i * w_i
   * n_i    += g_i^2
   *
   * @param theta the current weight vector (read only; returned unchanged).
   * @param gradient the gradient of the current example; only non-zero
   *          entries are processed.
   * @param learningRate alpha in the paper.
   * @param iteration the current iteration (unused here).
   * @param cost the cost, passed through unchanged.
   * @return the unchanged cost and theta.
   */
  @Override
  public CostWeightTuple computeNewWeights(DoubleVector theta,
      DoubleVector gradient, double learningRate, long iteration, double cost) {
    // guard against being called before prePredictionWeightUpdate
    ensureStateInitialized(theta);
    Iterator<DoubleVectorElement> iterateNonZero = gradient.iterateNonZero();
    while (iterateNonZero.hasNext()) {
      DoubleVectorElement next = iterateNonZero.next();
      double gradientValue = next.getValue();
      int index = next.getIndex();
      double zi = perCoordinateWeights.get(index);
      double ni = squaredPreviousGradient.get(index);
      // per-coordinate learning-rate schedule from the paper
      double sigma = (FastMath.sqrt(ni + gradientValue * gradientValue) - FastMath
          .sqrt(ni)) / learningRate;
      perCoordinateWeights.set(index,
          zi + gradientValue - sigma * theta.get(index));
      squaredPreviousGradient.set(index, ni + gradientValue * gradientValue);
    }
    return new CostWeightTuple(cost, theta);
  }

  /**
   * This updater does all of its work in the pre-prediction and
   * post-prediction hooks, so no gradient adjustment happens here.
   *
   * @return always null; callers of this implementation must not rely on a
   *         gradient adjustment from this method.
   */
  @Override
  public CostGradientTuple updateGradient(DoubleVector theta,
      DoubleVector gradient, double learningRate, long iteration, double cost) {
    return null;
  }

  /**
   * Lazily initializes the z and n accumulators as zeroed vectors of the same
   * type and dimension as the given weight vector.
   */
  private void ensureStateInitialized(DoubleVector theta) {
    if (squaredPreviousGradient == null) {
      squaredPreviousGradient = theta.deepCopy().multiply(0);
      perCoordinateWeights = theta.deepCopy().multiply(0);
    }
  }
}