package de.jungblut.online.regularization; import org.junit.Assert; import org.junit.Test; import de.jungblut.math.DoubleVector; import de.jungblut.math.dense.DenseDoubleVector; public class TestGradientDescentUpdater { @Test public void testGradientUpdate() { GradientDescentUpdater updater = new GradientDescentUpdater(); DoubleVector theta = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); DoubleVector grad = new DenseDoubleVector(new double[] { 1d, 1d, 1d }); double learningRate = 0.1d; CostWeightTuple update = updater.computeNewWeights(theta, grad, learningRate, 1, 1d); Assert.assertArrayEquals(theta.subtract(grad.multiply(learningRate)) .toArray(), update.getWeight().toArray(), 1e-8); Assert.assertEquals(1d, update.getCost(), 0d); } }