package de.jungblut.online.regularization; import de.jungblut.math.DoubleVector; import de.jungblut.math.minimize.CostGradientTuple; import de.jungblut.online.ml.FeatureOutcomePair; public class GradientDescentUpdater implements WeightUpdater { /** * Simplistic gradient descent without regularization. */ @Override public CostWeightTuple computeNewWeights(DoubleVector theta, DoubleVector gradient, double learningRate, long iteration, double cost) { CostGradientTuple gradientTuple = updateGradient(theta, gradient, learningRate, iteration, cost); DoubleVector dampened = gradientTuple.getGradient().multiply(learningRate); DoubleVector newWeights = theta.subtract(dampened); return new CostWeightTuple(gradientTuple.getCost(), newWeights); } @Override public CostGradientTuple updateGradient(DoubleVector theta, DoubleVector gradient, double learningRate, long iteration, double cost) { return new CostGradientTuple(cost, gradient); } @Override public DoubleVector prePredictionWeightUpdate( FeatureOutcomePair featureOutcome, DoubleVector theta, double learningRate, long iteration) { return theta; } }