package edu.berkeley.nlp.classify;
import java.util.ArrayList;
import java.util.List;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.CounterMap;
public class NaiveBayesClassifier<I,F,L> implements ProbabilisticClassifier<I, L> {
private CounterMap<L,F> featureProbs ;
private Counter<F> backoffProbs ;
private Counter<L> labelProbs ;
private FeatureExtractor<I, F> featureExtractor;
private double alpha = 0.1;
public static class Factory<I,F,L> implements ProbabilisticClassifierFactory<I, L> {
private FeatureExtractor<I, F> featureExtractor;
public Factory(FeatureExtractor<I, F> featureExtractor) {
this.featureExtractor = featureExtractor;
}
public ProbabilisticClassifier<I, L> trainClassifier(List<LabeledInstance<I, L>> trainingData) {
CounterMap<L, F> featureProbs = new CounterMap<L, F>();
Counter<F> backoffProbs = new Counter<F>();
Counter<L> labelProbs = new Counter<L>();
for (LabeledInstance<I, L> instance: trainingData) {
L label = instance.getLabel();
labelProbs.incrementCount(label, 1.0);
I inst = instance.getInput();
Counter<F> featCounts = featureExtractor.extractFeatures(inst);
for (F feat: featCounts.keySet()) {
double count = featCounts.getCount(feat);
backoffProbs.incrementCount(feat, count);
featureProbs.incrementCount(label, feat, count);
}
}
featureProbs.normalize();
labelProbs.normalize();
backoffProbs.normalize();
return new NaiveBayesClassifier<I, F, L>(featureProbs, backoffProbs, labelProbs, featureExtractor);
}
}
public Counter<L> getProbabilities(I instance) {
Counter<L> posteriors = new Counter<L>();
List<Double> logPosteriorsUnnormed = new ArrayList<Double>();
for (L label: labelProbs.keySet()) {
double logPrior = Math.log(labelProbs.getCount(label));
double logPosteriorUnnorm = logPrior;
Counter<F> featCounts =featureExtractor.extractFeatures(instance);
for (F feat: featCounts.keySet()) {
double count = featCounts.getCount(feat);
logPosteriorUnnorm += count * Math.log( getFeatureProb(feat, label) );
}
logPosteriorsUnnormed.add(logPosteriorUnnorm);
posteriors.setCount(label, logPosteriorUnnorm);
}
double logPosteriorNorm = SloppyMath.logAdd(logPosteriorsUnnormed);
for (L label: labelProbs.keySet()) {
double logPosteriorUnnorm = posteriors.getCount(label);
double logPosterior = logPosteriorUnnorm - logPosteriorNorm;
double posterior = Math.exp(logPosterior);
posteriors.setCount(label, posterior);
}
// TODO Auto-generated method stub
return posteriors;
}
private double getFeatureProb(F feat, L label) {
double mleProb = featureProbs.getCount(label, feat);
double backoffProb = backoffProbs.getCount(feat);
return (1-alpha) * mleProb + alpha * backoffProb;
}
public L getLabel(I instance) {
// TODO Auto-generated method stub
return getProbabilities(instance).argMax();
}
public NaiveBayesClassifier(CounterMap<L, F> featureProbs,
Counter<F> backoffProbs, Counter<L> labelProbs,
FeatureExtractor<I, F> featureExtractor) {
super();
this.featureProbs = featureProbs;
this.backoffProbs = backoffProbs;
this.labelProbs = labelProbs;
this.featureExtractor = featureExtractor;
}
}