package de.jungblut.online.bayes;

import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Stream;

import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.online.ml.AbstractOnlineLearner;
import de.jungblut.online.ml.FeatureOutcomePair;

/**
 * Multinomial naive bayes learner. This class now contains a sparse internal
 * representation of the "feature given class" probabilities. Thus it can be
 * scaled to very large text corpora and large numbers of classes easily.
 *
 * The internal accesses are thread-safe, so a parallel stream can be used.
 *
 * @author thomas.jungblut
 *
 */
public class NaiveBayesLearner extends
    AbstractOnlineLearner<BayesianProbabilityModel> {

  private static final Logger LOG = LogManager
      .getLogger(NaiveBayesLearner.class);

  // log P(feature | class), one sparse row per outcome class
  private DoubleMatrix probabilityMatrix;
  // log P(class), indexed by outcome class
  private DoubleVector classPriorProbability;
  private boolean verbose;

  /**
   * Default constructor to construct this classifier.
   */
  public NaiveBayesLearner() {
  }

  /**
   * Pass true if this classifier should output some progress information to
   * the logger.
   */
  public NaiveBayesLearner(boolean verbose) {
    this.verbose = verbose;
  }

  /**
   * Trains the model by counting feature occurrences per class over the
   * supplied stream and converting the counts into log probabilities.
   *
   * @param streamSupplier supplies the training examples; a parallel stream
   *          may be returned, as observation is synchronized internally.
   * @return the trained probability model (log likelihoods and log priors).
   */
  @Override
  public BayesianProbabilityModel train(
      Supplier<Stream<FeatureOutcomePair>> streamSupplier) {

    init(streamSupplier);
    Stream<FeatureOutcomePair> stream = streamSupplier.get();

    // sparse row representations, so every class has the features as a hashset
    // of values. This gives good compression for many class problems.
    probabilityMatrix = new SparseDoubleRowMatrix(numOutcomeClasses,
        featureDimension);
    int[] tokenPerClass = new int[numOutcomeClasses];
    int[] numDocumentsPerClass = new int[numOutcomeClasses];

    // observe the raw counts; AtomicInteger because the stream may be parallel
    AtomicInteger numDocumentsSeen = new AtomicInteger(0);
    stream.forEach((pair) -> {
      observe(pair.getFeature(), pair.getOutcome(), numOutcomeClasses,
          tokenPerClass, numDocumentsPerClass);
      numDocumentsSeen.incrementAndGet();
    });

    computeLogProbabilities(tokenPerClass);
    computeClassPriors(numDocumentsPerClass, numDocumentsSeen.get());

    return new BayesianProbabilityModel(probabilityMatrix,
        classPriorProbability);
  }

  /**
   * Turns the raw token counts stored in {@link #probabilityMatrix} into log
   * probabilities, row by row (one row per class).
   */
  private void computeLogProbabilities(int[] tokenPerClass) {
    // now we know the token distribution per class, we can calculate the
    // probability. It is intended for them to be negative in some cases
    for (int row = 0; row < numOutcomeClasses; row++) {
      // we can quite efficiently iterate over the non-zero row vectors now
      DoubleVector rowVector = probabilityMatrix.getRowVector(row);
      // don't care about non-occurring words, we honor them with a very small
      // probability later on when predicting, here we save a lot of space.
      Iterator<DoubleVectorElement> iterateNonZero = rowVector.iterateNonZero();
      // NOTE(review): normalizer uses (columnCount - 1) rather than the full
      // vocabulary size of classic Laplace smoothing — presumably deliberate
      // to pair with the unseen-word probability applied at prediction time;
      // confirm against the predictor before changing.
      double normalizer = FastMath.log(tokenPerClass[row]
          + probabilityMatrix.getColumnCount() - 1);
      while (iterateNonZero.hasNext()) {
        DoubleVectorElement next = iterateNonZero.next();
        double currentWordCount = next.getValue();
        double logProbability = FastMath.log(currentWordCount) - normalizer;
        probabilityMatrix.set(row, next.getIndex(), logProbability);
      }
      if (verbose) {
        LOG.info("Computed " + row + " / " + numOutcomeClasses + "!");
      }
    }
  }

  /**
   * Computes log class priors as log(docsInClass) - log(totalDocs).
   * A class that never occurred yields a -Infinity prior (log(0)), which
   * effectively rules it out at prediction time.
   */
  private void computeClassPriors(int[] numDocumentsPerClass,
      int numDocumentsSeen) {
    classPriorProbability = new DenseDoubleVector(numOutcomeClasses);
    for (int i = 0; i < numOutcomeClasses; i++) {
      double prior = FastMath.log(numDocumentsPerClass[i])
          - FastMath.log(numDocumentsSeen);
      classPriorProbability.set(i, prior);
    }
  }

  /**
   * Accumulates the counts of a single document into the shared state.
   * Thread-safe: all shared mutations happen under the probabilityMatrix
   * monitor.
   *
   * @param document the (sparse) feature vector of the document.
   * @param outcome the outcome vector; for the binary case the single value
   *          at index 0 is taken as the class id directly.
   */
  private void observe(DoubleVector document, DoubleVector outcome,
      int numDistinctClasses, int[] tokenPerClass, int[] numDocumentsPerClass) {

    int predictedClass = outcome.maxIndex();
    if (numDistinctClasses == 2) {
      predictedClass = (int) outcome.get(0);
    }

    // Previously the counters and every single token update each acquired
    // this same monitor separately (one lock per non-zero element). Since all
    // updates go through the one lock anyway and additions commute, batching
    // the whole document into a single synchronized section yields identical
    // final counts with far fewer lock acquisitions.
    synchronized (probabilityMatrix) {
      // NOTE(review): getLength() is presumably the number of distinct
      // non-zero features, not the summed token count — TODO confirm against
      // the DoubleVector contract; kept as-is to preserve behavior.
      tokenPerClass[predictedClass] += document.getLength();
      numDocumentsPerClass[predictedClass]++;

      Iterator<DoubleVectorElement> iterateNonZero = document.iterateNonZero();
      while (iterateNonZero.hasNext()) {
        DoubleVectorElement next = iterateNonZero.next();
        double currentCount = probabilityMatrix.get(predictedClass,
            next.getIndex());
        probabilityMatrix.set(predictedClass, next.getIndex(), currentCount
            + next.getValue());
      }
    }
  }

}