// Stanford Classifier - a multiclass maxent classifier
// NaiveBayesClassifier
// Copyright (c) 2003-2007 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    Support/Questions: java-nlp-user@lists.stanford.edu
//    Licensing: java-nlp-support@lists.stanford.edu
//    http://www-nlp.stanford.edu/software/classifier.shtml

package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;

import java.io.PrintStream;
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;

/**
 * A Naive Bayes classifier with a fixed number of features.
 * The features are assumed to have integer values even though RVFDatum will return doubles.
 *
 * @author Kristina Toutanova (kristina@cs.stanford.edu)
 * @author Sarah Spikes (sdspikes@cs.stanford.edu) - Templatization. Not sure what the weights counter
 *         is supposed to hold; given the weights function it seems to hold {@code Pair<Pair<L, F>, Object>}
 *         but this seems like a strange thing to expect to be passed in.
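 *
 * <p>
 * A minimal construction sketch. All labels, features, and log-space weight/prior
 * values below are hypothetical, hand-filled numbers for illustration only; in
 * practice the counters come from a training procedure (e.g. a factory such as
 * {@code NaiveBayesClassifierFactory}) rather than being built by hand:
 * <pre>{@code
 * // weight for a (label, feature) pair at integer feature value 1
 * Counter<Pair<Pair<String, String>, Number>> weights = new ClassicCounter<>();
 * weights.setCount(new Pair<>(new Pair<>("spam", "offer"), 1), Math.log(0.8));
 * weights.setCount(new Pair<>(new Pair<>("ham", "offer"), 1), Math.log(0.1));
 *
 * Counter<String> priors = new ClassicCounter<>();
 * priors.setCount("spam", Math.log(0.4));
 * priors.setCount("ham", Math.log(0.6));
 *
 * Set<String> labels = new HashSet<>(Arrays.asList("spam", "ham"));
 * NaiveBayesClassifier<String, String> nb =
 *     new NaiveBayesClassifier<>(weights, priors, labels);
 *
 * Counter<String> feats = new ClassicCounter<>();
 * feats.setCount("offer", 1.0);
 * RVFDatum<String, String> datum = new RVFDatum<>(feats, null);  // label unknown at test time
 * String guess = nb.classOf(datum);  // "spam": log(0.4) + log(0.8) > log(0.6) + log(0.1)
 * }</pre>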
 */
public class NaiveBayesClassifier<L, F> implements Classifier<L, F>, RVFClassifier<L, F> {

  private static final long serialVersionUID = 1544820342684024068L;

  private Counter<Pair<Pair<L, F>, Number>> weights; // keys pair a (class, feature) pair with a feature value
  private Counter<L> priors;
  private Set<F> features; // we need all features to add the weights for zero-valued ones
  private boolean addZeroValued; // whether to add features as having value 0 if they are not in the Datum/RVFDatum
  private Counter<L> priorZero; // if we need to add the zeros, pre-compute the weight for all zeros for each class
  private Set<L> labels;

  private final Integer zero = Integer.valueOf(0);

  final static Redwood.RedwoodChannels logger = Redwood.channels(NaiveBayesClassifier.class);

  public Collection<L> labels() {
    return labels;
  }

  public L classOf(RVFDatum<L, F> example) {
    Counter<L> scores = scoresOf(example);
    return Counters.argmax(scores);
  }

  public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
    ClassicCounter<L> scores = new ClassicCounter<>();
    Counters.addInPlace(scores, priors);
    if (addZeroValued) {
      Counters.addInPlace(scores, priorZero);
    }
    Counter<F> exampleFeatures = example.asFeaturesCounter();
    for (L l : labels) {
      double score = 0.0;
      for (F f : exampleFeatures.keySet()) {
        int value = (int) exampleFeatures.getCount(f);
        score += weight(l, f, Integer.valueOf(value));
        if (addZeroValued) {
          score -= weight(l, f, zero);
        }
      }
      scores.incrementCount(l, score);
    }
    return scores;
  }

  public L classOf(Datum<L, F> example) {
    RVFDatum<L, F> rvf = new RVFDatum<>(example);
    return classOf(rvf);
  }

  public ClassicCounter<L> scoresOf(Datum<L, F> example) {
    RVFDatum<L, F> rvf = new RVFDatum<>(example);
    return scoresOf(rvf);
  }

  public NaiveBayesClassifier(Counter<Pair<Pair<L, F>, Number>> weights, Counter<L> priors, Set<L> labels,
                              Set<F> features, boolean addZero) {
    this.weights = weights;
    this.features = features;
    this.priors = priors;
    this.labels = labels;
    addZeroValued = addZero;
    if (addZeroValued) {
      initZeros();
    }
  }

  public float accuracy(Iterator<RVFDatum<L, F>> exampleIterator) {
    int correct = 0;
    int total = 0;
    while (exampleIterator.hasNext()) {
      RVFDatum<L, F> next = exampleIterator.next();
      L guess = classOf(next);
      if (guess.equals(next.label())) {
        correct++;
      }
      total++;
    }
    logger.info("correct " + correct + " out of " + total);
    return correct / (float) total;
  }

  public void print(PrintStream pw) {
    pw.println("priors ");
    pw.println(priors.toString());
    pw.println("weights ");
    pw.println(weights.toString());
  }

  public void print() {
    print(System.out);
  }

  private double weight(L label, F feature, Number val) {
    Pair<Pair<L, F>, Number> p = new Pair<>(new Pair<>(label, feature), val);
    return weights.getCount(p);
  }

  public NaiveBayesClassifier(Counter<Pair<Pair<L, F>, Number>> weights, Counter<L> priors, Set<L> labels) {
    this(weights, priors, labels, null, false);
  }

  /**
   * In case the features for which there is a value 0 in an example need to have their coefficients multiplied in,
   * we need to pre-compute the addition
   * priorZero(l) = sum_{features} wt(l, feat=0)
   */
  private void initZeros() {
    priorZero = new ClassicCounter<>();
    for (L label : labels) {
      double score = 0;
      for (F feature : features) {
        score += weight(label, feature, zero);
      }
      priorZero.setCount(label, score);
    }
  }

}