package de.jungblut.online.regularization;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.online.ml.FeatureOutcomePair;
// TODO split this into three interfaces
/**
 * Contract for the weight-related hooks declared below: an adjustment of the
 * weights before a prediction is made, the computation of new weights from a
 * gradient, and an adjustment of the gradient itself.
 */
public interface WeightUpdater {

  /**
   * Produces the weights that should be used before the prediction takes
   * place.
   *
   * @param featureOutcome the feature/outcome pair currently being processed.
   * @param theta the weights to augment.
   * @param learningRate the learning rate.
   * @param iteration the number of the current iteration.
   * @return either a modified weight vector, or simply {@code theta} when no
   *         change is needed.
   */
  DoubleVector prePredictionWeightUpdate(
      FeatureOutcomePair featureOutcome,
      DoubleVector theta,
      double learningRate,
      long iteration);

  /**
   * Derives the new weights from the old weights and a gradient.
   *
   * @param theta the old weights.
   * @param gradient the gradient pre-computed by the loss function.
   * @param learningRate the learning rate.
   * @param iteration the number of the current iteration.
   * @param cost the cost computed for this gradient update.
   * @return the weights with the update for the given gradient already
   *         applied.
   */
  CostWeightTuple computeNewWeights(
      DoubleVector theta,
      DoubleVector gradient,
      double learningRate,
      long iteration,
      double cost);

  /**
   * Adjusts the gradient (and the cost that belongs to it).
   *
   * @param theta the old weights.
   * @param gradient the gradient pre-computed by the loss function.
   * @param learningRate the learning rate.
   * @param iteration the number of the current iteration.
   * @param cost the cost computed for this gradient update.
   * @return the gradient vector that should be subtracted from the weights,
   *         together with the updated cost.
   */
  CostGradientTuple updateGradient(
      DoubleVector theta,
      DoubleVector gradient,
      double learningRate,
      long iteration,
      double cost);
}