package de.jungblut.online.regression;
import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractPredictor;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.dense.SingleEntryDoubleVector;
/**
* Classifier for regression model. Takes a model or the atomic parts of it and
* predicts the outcome for a given feature.
*
* @author thomas.jungblut
*
*/
public class RegressionClassifier extends AbstractPredictor {

  /** Weights and activation function that every prediction is computed from. */
  private final RegressionModel model;

  /**
   * Creates a classifier that predicts with the given regression model.
   *
   * @param model the model holding weights and activation function.
   * @throws NullPointerException if {@code model} is null.
   */
  public RegressionClassifier(RegressionModel model) {
    this.model = Preconditions.checkNotNull(model, "model");
  }

  /**
   * Creates a classifier by wrapping the given weights and activation function
   * into a new {@link RegressionModel}.
   *
   * @param weights the model weights, one per feature dimension.
   * @param function the activation applied to the weighted sum.
   */
  public RegressionClassifier(DoubleVector weights, ActivationFunction function) {
    this(new RegressionModel(weights, function));
  }

  /**
   * Scores a single feature vector.
   *
   * <p>The score is the dot product of the feature with the model weights,
   * passed through the model's activation function.
   *
   * @param feature the input vector; its dimension must equal the weight
   *          dimension.
   * @return a one-element vector containing the activated score.
   * @throws IllegalArgumentException if the feature and weight dimensions
   *           differ.
   */
  @Override
  public DoubleVector predict(DoubleVector feature) {
    DoubleVector weights = model.getWeights();
    int featureDimension = feature.getDimension();
    int weightDimension = weights.getDimension();
    Preconditions.checkArgument(featureDimension == weightDimension,
        "feature dimension must match model weight dimension! Feature: "
            + featureDimension + " != Model: " + weightDimension);

    // Weighted sum first, then the (possibly non-linear) activation.
    double weightedSum = feature.dot(weights);
    ActivationFunction activation = model.getActivationFunction();
    double activated = activation.apply(weightedSum);
    return new SingleEntryDoubleVector(activated);
  }
}