package de.jungblut.online.regularization; import org.junit.Assert; import org.junit.Test; import de.jungblut.math.DoubleVector; import de.jungblut.math.dense.DenseDoubleVector; public class TestL2Regularizer { @Test public void testGradientUpdate() { WeightUpdater updater = new L2Regularizer(1d); DoubleVector theta = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); DoubleVector grad = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); double learningRate = 0.1d; CostWeightTuple update = updater.computeNewWeights(theta, grad, learningRate, 1, 1d); double[] expected = new double[] { 0.9, 0.8, 0.8 }; Assert.assertArrayEquals(expected, update.getWeight().toArray(), 1e-8); Assert.assertEquals(2d, update.getCost(), 0d); } @Test public void testNoOpUpdate() { WeightUpdater updater = new L2Regularizer(0d); DoubleVector theta = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); DoubleVector grad = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); double learningRate = 0.1d; CostWeightTuple update = updater.computeNewWeights(theta, grad, learningRate, 1, 1d); Assert.assertArrayEquals(theta.subtract(grad.multiply(learningRate)) .toArray(), update.getWeight().toArray(), 1e-8); Assert.assertEquals(1d, update.getCost(), 0d); } }