package de.jungblut.online.bayes;
import java.util.Iterator;
import org.apache.commons.math3.util.FastMath;
import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractPredictor;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;
import de.jungblut.math.dense.DenseDoubleVector;
/**
 * Predictor that turns the statistics held by a
 * {@link BayesianProbabilityModel} into a normalized per-class distribution
 * for a feature vector.
 * <p>
 * Per-token values read from the model's probability matrix are multiplied by
 * the feature value and summed, and the class prior is added inside the
 * exponent, so both are treated here as log-space quantities.
 */
public class BayesianClassifier extends AbstractPredictor {

  /**
   * Log-space value substituted for a token whose entry in the probability
   * matrix is exactly zero.
   */
  private static final double LOW_PROBABILITY = FastMath.log(1e-8);

  private final BayesianProbabilityModel model;

  /**
   * Creates a classifier backed by the given model.
   *
   * @param model the model to read priors and token probabilities from; must
   *          not be null.
   * @throws NullPointerException if {@code model} is null.
   */
  public BayesianClassifier(BayesianProbabilityModel model) {
    super();
    this.model = Preconditions.checkNotNull(model, "model");
  }

  /**
   * Predicts the class distribution for the given features.
   *
   * @param features the feature vector to classify.
   * @return a vector with one entry per class, normalized by the sum of all
   *         entries.
   */
  @Override
  public DoubleVector predict(DoubleVector features) {
    return getProbabilityDistribution(features);
  }

  /**
   * Sums, over all non-zero features of the document, the feature value
   * multiplied by the model's matrix entry for the given class and feature
   * index.
   *
   * @param document the feature vector to score.
   * @param classIndex the row of the probability matrix to read.
   * @return the accumulated score (without the class prior).
   */
  private double getProbabilityForClass(DoubleVector document, int classIndex) {
    double probabilitySum = 0.0d;
    Iterator<DoubleVectorElement> iterateNonZero = document.iterateNonZero();
    while (iterateNonZero.hasNext()) {
      DoubleVectorElement next = iterateNonZero.next();
      double wordCount = next.getValue();
      double probabilityOfToken = model.getProbabilityMatrix().get(classIndex,
          next.getIndex());
      // an entry of exactly zero is replaced by a fixed low fallback value
      if (probabilityOfToken == 0d) {
        probabilityOfToken = LOW_PROBABILITY;
      }
      probabilitySum += (wordCount * probabilityOfToken);
    }
    return probabilitySum;
  }

  /**
   * Computes the normalized distribution over all classes for a document.
   * <p>
   * The class prior is added to each class score <em>before</em> the maximum
   * is subtracted. The highest-scoring class therefore always contributes
   * {@code exp(0) = 1}, so the normalization sum cannot underflow to zero as
   * long as at least one combined score is finite. The shift cancels out in
   * the final division, so the result is mathematically the same as shifting
   * by any other constant.
   *
   * @param document the feature vector to classify.
   * @return a dense vector with one normalized entry per class.
   */
  private DenseDoubleVector getProbabilityDistribution(DoubleVector document) {
    // fetch the priors once instead of on every loop iteration
    DoubleVector classPriors = model.getClassPriorProbability();
    int numClasses = classPriors.getLength();
    DenseDoubleVector distribution = new DenseDoubleVector(numClasses);
    // combined score per class: token score plus class prior
    for (int i = 0; i < numClasses; i++) {
      double score = getProbabilityForClass(document, i) + classPriors.get(i);
      distribution.set(i, score);
    }

    // shift by the largest combined score so the exponent is never positive
    // and is exactly zero for the best class
    double maxScore = distribution.max();
    double probabilitySum = 0.0d;
    for (int i = 0; i < numClasses; i++) {
      double unnormalized = FastMath.exp(distribution.get(i) - maxScore);
      distribution.set(i, unnormalized);
      probabilitySum += unnormalized;
    }

    // the shifted values do not sum to 1, so divide by their sum
    distribution = (DenseDoubleVector) distribution.divide(probabilitySum);
    return distribution;
  }
}