package de.jungblut.online.regularization;
import java.util.Iterator;
import org.apache.commons.math3.util.FastMath;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;
/**
 * L1 regularizer, ported to "real" Java from Spark's mllib {@code L1Updater}
 * (package {@code org.apache.spark.mllib.optimization}).
 *
 * <p>Regularizer: R(w) = ||w||_1, taken over every weight except the one at
 * index 0 (presumably the bias/intercept term). That weight is neither shrunk
 * nor penalized.
 *
 * <p>Instead of using the subgradient of the regularizer, the proximal operator
 * for L1 regularization (soft thresholding) is applied after the plain gradient
 * step. This is known to result in better sparsity of the intermediate
 * solution. Weights whose magnitude after shrinkage is below a tolerance are
 * additionally set to exactly zero.
 *
 * <p>When the penalty is non-zero, the step size is used exactly as passed in
 * {@code learningRate}. This class does not decay it by the iteration count
 * (Spark's version divides it by the square root of the iteration), so any
 * such schedule must be applied by the caller.
 */
public final class L1Regularizer extends GradientDescentUpdater {

  /** Magnitude strictly below which a shrunk weight is set to exactly zero. */
  private final double tol;
  /** Strength of the L1 penalty. */
  private final double l1;

  /**
   * Creates an L1 regularizer whose zeroing tolerance equals the penalty
   * strength.
   *
   * @param l1 the L1 penalty strength; must be non-negative. A value of zero
   *          turns this into a plain gradient descent update.
   * @throws IllegalArgumentException if {@code l1} is negative or NaN.
   */
  public L1Regularizer(double l1) {
    this(l1, l1);
  }

  /**
   * Creates an L1 regularizer with an explicit zeroing tolerance.
   *
   * @param l1 the L1 penalty strength; must be non-negative. A value of zero
   *          turns this into a plain gradient descent update.
   * @param tol weights whose magnitude after shrinkage is strictly below this
   *          value are set to exactly zero.
   * @throws IllegalArgumentException if {@code l1} is negative or NaN.
   */
  public L1Regularizer(double l1, double tol) {
    // a negative penalty would grow the weights instead of shrinking them,
    // and NaN would silently corrupt every updated weight
    if (Double.isNaN(l1) || l1 < 0d) {
      throw new IllegalArgumentException(
          "l1 must be a non-negative number, but was " + l1);
    }
    this.l1 = l1;
    this.tol = tol;
  }

  /**
   * Performs one gradient step followed by L1 soft thresholding of every
   * weight except index 0.
   *
   * @param theta the current weights.
   * @param gradient the gradient for {@code theta}.
   * @param learningRate the step size, used as-is.
   * @param iteration the current iteration. It is only forwarded to the plain
   *          gradient descent update when {@code l1 == 0}.
   * @param cost the cost value to which the L1 penalty is added.
   * @return the updated weights together with {@code cost} plus
   *         {@code l1 * ||w_new||_1}. The norm is taken over the returned
   *         weights, excluding index 0.
   */
  @Override
  public CostWeightTuple computeNewWeights(DoubleVector theta,
      DoubleVector gradient, double learningRate, long iteration, double cost) {
    if (l1 == 0d) {
      // without a penalty this degenerates to a plain gradient descent step
      return super.computeNewWeights(theta, gradient, learningRate, iteration,
          cost);
    }

    DoubleVector newWeights = theta.subtract(gradient.multiply(learningRate));
    double shrinkageVal = l1 * learningRate;
    double l1Norm = 0d;
    if (newWeights.isSparse()) {
      // Writes go to a copy, presumably so that zeroing entries does not
      // disturb the iteration over the non-zero entries of newWeights.
      DoubleVector shrunk = newWeights.deepCopy();
      Iterator<DoubleVectorElement> nonZero = newWeights.iterateNonZero();
      while (nonZero.hasNext()) {
        DoubleVectorElement element = nonZero.next();
        if (element.getIndex() > 0) {
          l1Norm += shrinkWeight(shrunk, shrinkageVal, element.getIndex(),
              element.getValue());
        }
      }
      newWeights = shrunk;
    } else {
      // index 0 is intentionally skipped
      for (int i = 1; i < newWeights.getDimension(); i++) {
        l1Norm += shrinkWeight(newWeights, shrinkageVal, i, newWeights.get(i));
      }
    }
    // the penalty reflects the weights actually returned, as in Spark
    return new CostWeightTuple(cost + l1Norm * l1, newWeights);
  }

  /**
   * Soft-thresholds a single weight and stores the result.
   *
   * @param target the vector that receives the shrunk weight.
   * @param shrinkageVal the amount by which the weight's magnitude is reduced.
   * @param index the index of the weight in {@code target}.
   * @param weight the weight value before shrinkage.
   * @return the absolute value of the stored (shrunk) weight, that is, its
   *         contribution to the L1 norm of the updated weights.
   */
  private double shrinkWeight(DoubleVector target, double shrinkageVal,
      int index, double weight) {
    double newWeight = FastMath.signum(weight)
        * FastMath.max(0.0, FastMath.abs(weight) - shrinkageVal);
    if (FastMath.abs(newWeight) < tol) {
      // snap tiny weights to exactly zero to keep the solution sparse
      newWeight = 0;
    }
    target.set(index, newWeight);
    return FastMath.abs(newWeight);
  }
}