package de.jungblut.online.regression.multinomial;
import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractPredictor;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.online.regression.RegressionClassifier;
/**
 * Classifier for multinomial regression. It wraps one {@link RegressionClassifier}
 * per sub-model of a {@link MultinomialRegressionModel} and combines their
 * single-dimensional outputs into one vector. It can optionally normalize that
 * vector by its sum.
 *
 * @author thomas.jungblut
 *
 */
public class MultinomialRegressionClassifier extends AbstractPredictor {

  /** One classifier per sub-model, in the order returned by {@code getModels()}. */
  private final RegressionClassifier[] classifier;
  /** Whether {@link #predict(DoubleVector)} divides the output by its sum. */
  private final boolean normalize;

  /**
   * Constructs a new multinomial regression classifier that does normalization
   * over independent predictions. Equivalent to {@code this(model, true)}.
   *
   * @param model the trained model.
   * @throws NullPointerException if {@code model} is null.
   */
  public MultinomialRegressionClassifier(MultinomialRegressionModel model) {
    this(model, true);
  }

  /**
   * Constructs a new multinomial regression classifier. If {@code normalize} is
   * true, the independent predictions are normalized by summing over them and
   * dividing each entry by that sum. Otherwise the raw predictions are returned.
   *
   * @param model the trained model.
   * @param normalize true for normalizing the output.
   * @throws NullPointerException if {@code model} is null.
   */
  public MultinomialRegressionClassifier(MultinomialRegressionModel model,
      boolean normalize) {
    // Validate before touching any state so a null model fails fast.
    Preconditions.checkNotNull(model, "model");
    this.normalize = normalize;
    this.classifier = new RegressionClassifier[model.getModels().length];
    for (int i = 0; i < model.getModels().length; i++) {
      classifier[i] = new RegressionClassifier(model.getModels()[i]);
    }
  }

  /**
   * Predicts the output for the given feature vector by querying every wrapped
   * sub-model classifier. Each classifier must return a single-dimensional
   * vector, and its only entry becomes the corresponding entry of the result.
   * If normalization is enabled and the entries sum to a non-zero value, the
   * result is divided by that sum. A zero sum is left untouched to avoid
   * dividing by zero.
   *
   * @param feature the feature vector to predict.
   * @return a vector with one entry per sub-model.
   * @throws IllegalArgumentException if a sub-model's prediction does not have
   *     exactly one dimension.
   */
  @Override
  public DoubleVector predict(DoubleVector feature) {
    DoubleVector mesh = new DenseDoubleVector(classifier.length);
    for (int i = 0; i < classifier.length; i++) {
      RegressionClassifier clf = classifier[i];
      DoubleVector prediction = clf.predict(feature);
      Preconditions.checkArgument(prediction.getDimension() == 1,
          "Prediction only works for a single dimensional output! Given "
              + prediction.getDimension());
      mesh.set(i, prediction.get(0));
    }
    if (normalize) {
      double sum = mesh.sum();
      if (sum != 0d) {
        mesh = mesh.divide(sum);
      }
    }
    return mesh;
  }
}