package de.jungblut.online.regularization; import java.util.function.Function; import org.junit.Assert; import org.junit.Test; import de.jungblut.math.DoubleVector; import de.jungblut.math.dense.DenseDoubleVector; import de.jungblut.math.sparse.SequentialSparseDoubleVector; import de.jungblut.math.sparse.SparseDoubleVector; public class TestL1Regularizer { @Test public void testGradientUpdate() { WeightUpdater updater = new L1Regularizer(1d, 0d); DoubleVector theta = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); DoubleVector grad = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); double learningRate = 0.1d; CostWeightTuple update = updater.computeNewWeights(theta, grad, learningRate, 1, 1d); double[] expected = new double[] { 0.9, 0.8, 0.8 }; Assert.assertArrayEquals(expected, update.getWeight().toArray(), 1e-8); Assert.assertEquals(2.8d, update.getCost(), 1e-8); } @Test public void testNoOpUpdate() { WeightUpdater updater = new L1Regularizer(0d, 0d); DoubleVector theta = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); DoubleVector grad = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); double learningRate = 0.1d; CostWeightTuple update = updater.computeNewWeights(theta, grad, learningRate, 1, 1d); Assert.assertArrayEquals(theta.subtract(grad.multiply(learningRate)) .toArray(), update.getWeight().toArray(), 1e-8); Assert.assertEquals(1d, update.getCost(), 0d); } @Test public void testToleranceRemovalDense() { baseToleranceRemoval((vec) -> new DenseDoubleVector(vec)); } @Test public void testToleranceRemovalSeqSparse() { baseToleranceRemoval((vec) -> new SequentialSparseDoubleVector(vec)); } @Test public void testToleranceRemovalSparse() { baseToleranceRemoval((vec) -> new SparseDoubleVector(vec)); } public void baseToleranceRemoval( Function<double[], DoubleVector> vectorFactory) { WeightUpdater updater = new L1Regularizer(1d, 0.75); DoubleVector theta = vectorFactory.apply(new double[] { 1d, 1d, 1d }); DoubleVector grad = vectorFactory.apply(new double[] { 1d, 1d, 2d }); double learningRate = 0.1d; CostWeightTuple update = updater.computeNewWeights(theta, grad, learningRate, 1, 1d); double[] expected = new double[] { 0.9, 0.8, 0 }; Assert.assertArrayEquals(expected, update.getWeight().toArray(), 1e-8); Assert.assertEquals(2.7d, update.getCost(), 1e-8); } @Test public void testSparseVectors() { WeightUpdater updater = new L1Regularizer(1d, 0d); DoubleVector theta = new SequentialSparseDoubleVector(new double[] { 1d, 1d, 1d }); DoubleVector grad = new SequentialSparseDoubleVector(new double[] { 1d, 1d, 1d }); double learningRate = 0.1d; CostWeightTuple update = updater.computeNewWeights(theta, grad, learningRate, 1, 1d); double[] expected = new double[] { 0.9, 0.8, 0.8 }; Assert.assertArrayEquals(expected, update.getWeight().toArray(), 1e-8); Assert.assertEquals(2.8d, update.getCost(), 1e-8); } }