package de.jungblut.online.bayes;
import static org.junit.Assert.assertEquals;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.commons.math3.util.FastMath;
import org.junit.Test;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleVector;
import de.jungblut.online.ml.FeatureOutcomePair;
/**
 * Tests {@link NaiveBayesLearner} by training it on a five-example,
 * five-feature, two-class toy data set and checking the resulting
 * {@link BayesianProbabilityModel} as well as the predictions a
 * {@link BayesianClassifier} makes with it.
 */
public class TestNaiveBayesLearner {

  /** Tolerance used when comparing class priors. */
  private static final double PRIOR_DELTA = 0.01d;

  /** Tolerance used when comparing matrix entries and predictions. */
  private static final double VALUE_DELTA = 0.05d;

  /**
   * Trains a model on the toy data set and verifies it via
   * {@link #checkModel(BayesianProbabilityModel)}.
   */
  @Test
  public void testSimpleNaiveBayes() {
    BayesianProbabilityModel model = getTrainedModel();
    checkModel(model);
  }

  /**
   * Asserts that the given model matches what is expected after training on
   * the data produced by {@link #getTrainedModel()}: the class priors, both
   * rows of the probability matrix, and the predictions for two probe
   * vectors.
   *
   * @param model the trained model to verify.
   */
  public static void checkModel(BayesianProbabilityModel model) {
    BayesianClassifier classifier = new BayesianClassifier(model);

    // The training set has 2 examples of class 0 and 3 of class 1; the priors
    // are compared against the logarithm of those relative frequencies.
    DoubleVector classProbability = model.getClassPriorProbability();
    assertEquals(FastMath.log(2d / 5d), classProbability.get(0), PRIOR_DELTA);
    assertEquals(FastMath.log(3d / 5d), classProbability.get(1), PRIOR_DELTA);

    DoubleMatrix mat = model.getProbabilityMatrix();
    double[] realFirstRow = new double[] { 0.0, 0.0, -2.1972245773362196,
        -1.5040773967762742, -1.5040773967762742 };
    double[] realSecondRow = new double[] { -0.9808292530117262,
        -2.0794415416798357, 0.0, 0.0, 0.0 };

    assertRowEquals(realFirstRow, mat.getRowVector(0).toArray());
    assertRowEquals(realSecondRow, mat.getRowVector(1).toArray());

    // Only the first feature is set, which occurs exclusively in class 1.
    DoubleVector claz = classifier.predict(new DenseDoubleVector(new double[] {
        1, 0, 0, 0, 0 }));
    assertEquals("" + claz, 0, claz.get(0), VALUE_DELTA);
    assertEquals("" + claz, 1, claz.get(1), VALUE_DELTA);

    // The last two features are set, which occur exclusively in class 0.
    claz = classifier.predict(new DenseDoubleVector(new double[] { 0, 0, 0, 1,
        1 }));
    assertEquals("" + claz, 1, claz.get(0), VALUE_DELTA);
    assertEquals("" + claz, 0, claz.get(1), VALUE_DELTA);
  }

  /**
   * Asserts that {@code actual} has the same length as {@code expected} and
   * that every element matches within {@link #VALUE_DELTA}. The loop is
   * bounded by the length of the row being checked, so each row is always
   * verified against its own size.
   *
   * @param expected the reference values.
   * @param actual the row taken from the model's probability matrix.
   */
  private static void assertRowEquals(double[] expected, double[] actual) {
    assertEquals(expected.length, actual.length);
    for (int i = 0; i < actual.length; i++) {
      // The whole row is used as the failure message to ease debugging.
      assertEquals("" + Arrays.toString(actual), expected[i], actual[i],
          VALUE_DELTA);
    }
  }

  /**
   * Trains a {@link NaiveBayesLearner} on five sparse feature vectors with
   * five dimensions each: the first three are labelled with outcome 1, the
   * last two with outcome 0.
   *
   * @return the model returned by the learner.
   */
  public static BayesianProbabilityModel getTrainedModel() {
    NaiveBayesLearner learner = new NaiveBayesLearner();

    DoubleVector[] features = new DoubleVector[] {
        new SparseDoubleVector(new double[] { 1, 0, 0, 0, 0 }),
        new SparseDoubleVector(new double[] { 1, 0, 0, 0, 0 }),
        new SparseDoubleVector(new double[] { 1, 1, 0, 0, 0 }),
        new SparseDoubleVector(new double[] { 0, 0, 1, 1, 1 }),
        new SparseDoubleVector(new double[] { 0, 0, 0, 1, 1 }), };

    // outcome[i] is the class label belonging to features[i].
    DenseDoubleVector[] outcome = new DenseDoubleVector[] {
        new DenseDoubleVector(new double[] { 1 }),
        new DenseDoubleVector(new double[] { 1 }),
        new DenseDoubleVector(new double[] { 1 }),
        new DenseDoubleVector(new double[] { 0 }),
        new DenseDoubleVector(new double[] { 0 }), };

    // The learner is handed a supplier that builds a fresh stream of
    // feature/outcome pairs on every invocation.
    return learner.train(() -> IntStream.range(0, features.length).mapToObj(
        (i) -> new FeatureOutcomePair(features[i], outcome[i])));
  }
}