package de.jungblut.online.regression;

import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.dense.SingleEntryDoubleVector;
import de.jungblut.math.loss.LossFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.online.minimizer.StochasticMinimizer;
import de.jungblut.online.ml.AbstractMinimizingOnlineLearner;
import de.jungblut.online.ml.FeatureOutcomePair;

/**
 * A regression learner that learns a weight vector from a stream of examples,
 * given an activation function and an optimization objective (e.g. log loss).
 * This learner outputs a {@link RegressionModel} that can be used in a
 * RegressionClassifier.
 *
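 * A minimal usage sketch; the builder, activation, and loss classes named
 * below are assumed to exist elsewhere in this library and are not defined
 * in this file:
 *
 * <pre>{@code
 * // assumed API: a stochastic gradient descent minimizer with learning rate 0.1
 * StochasticMinimizer minimizer = StochasticGradientDescentBuilder.create(0.1).build();
 * RegressionLearner learner = new RegressionLearner(minimizer,
 *     new SigmoidActivationFunction(), new LogLoss());
 * // examples is a hypothetical List<FeatureOutcomePair> prepared by the caller;
 * // train is assumed to take a supplier so the stream can be re-read per pass
 * RegressionModel model = learner.train(() -> examples.stream());
 * }</pre>
 *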
 * @author thomas.jungblut
 *
 */
public class RegressionLearner extends
    AbstractMinimizingOnlineLearner<RegressionModel> {

  private final ActivationFunction activationFunction;
  private final LossFunction lossFunction;

  public RegressionLearner(StochasticMinimizer minimizer,
      ActivationFunction activationFunction, LossFunction lossFunction) {
    super(minimizer);
    this.activationFunction = Preconditions.checkNotNull(activationFunction,
        "activation function");
    this.lossFunction = Preconditions.checkNotNull(lossFunction,
        "loss function");
  }
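
  /**
   * Observes a single example: computes the hypothesis
   * {@code h = activation(w . x)}, the scalar cost {@code loss(y, h)}, and the
   * gradient of the loss w.r.t. the weights. As an illustration, for the
   * sigmoid activation combined with log loss this gradient reduces to the
   * well-known {@code (h - y) * x}.
   */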
  @Override
  protected CostGradientTuple observeExample(FeatureOutcomePair next,
      DoubleVector weights) {
    // hypothesis: activation applied to the dot product of features and weights
    DoubleVector hypothesis = new SingleEntryDoubleVector(
        activationFunction.apply(next.getFeature().dot(weights)));
    // scalar cost of this single example under the configured loss
    double cost = lossFunction.calculateLoss(next.getOutcome(), hypothesis);
    // gradient of the loss w.r.t. the weights, consumed by the minimizer
    DoubleVector gradient = lossFunction.calculateGradient(next.getFeature(),
        next.getOutcome(), hypothesis);
    return new CostGradientTuple(cost, gradient);
  }
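
  /**
   * Wraps the learned weights and the configured activation function into a
   * {@link RegressionModel} for downstream prediction.
   */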
  @Override
  public RegressionModel createModel(DoubleVector weights) {
    return new RegressionModel(weights, activationFunction);
  }
}