package edu.stanford.nlp.classify; import edu.stanford.nlp.util.Pair; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.ling.RVFDatum; /** * This class represents a trained SVM Classifier. It is actually just a * LinearClassifier, but it can have a Platt (sigmoid) model overlaying * it for the purpose of producing meaningful probabilities. * * @author Jenny Finkel * @author Sarah Spikes (sdspikes@cs.stanford.edu) (templatization) */ public class SVMLightClassifier<L, F> extends LinearClassifier<L, F> { /** * */ private static final long serialVersionUID = 1L; public LinearClassifier<L, L> platt = null; public SVMLightClassifier(ClassicCounter<Pair<F, L>> weightCounter, ClassicCounter<L> thresholds) { super(weightCounter, thresholds); } public SVMLightClassifier(ClassicCounter<Pair<F, L>> weightCounter, ClassicCounter<L> thresholds, LinearClassifier<L, L> platt) { super(weightCounter, thresholds); this.platt = platt; } public void setPlatt(LinearClassifier<L, L> platt) { this.platt = platt; } /** * Returns a counter for the log probability of each of the classes * looking at the the sum of e^v for each count v, should be 1 * Note: Uses SloppyMath.logSum which isn't exact but isn't as * offensively slow as doing a series of exponentials */ @Override public Counter<L> logProbabilityOf(Datum<L, F> example) { if (platt == null) { throw new UnsupportedOperationException("If you want to ask for the probability, you must train a Platt model!"); } Counter<L> scores = scoresOf(example); scores.incrementCount(null); Counter<L> probs = platt.logProbabilityOf(new RVFDatum<>(scores)); //System.out.println(scores+" "+probs); return probs; } /** * Returns a counter for the log probability of each of the classes * looking at the the sum of e^v for each count v, should be 1 * Note: Uses SloppyMath.logSum which isn't exact but isn't as * offensively slow as doing a series of exponentials */ @Override public Counter<L> logProbabilityOf(RVFDatum<L, F> example) { if (platt == null) { throw new UnsupportedOperationException("If you want to ask for the probability, you must train a Platt model!"); } Counter<L> scores = scoresOf(example); scores.incrementCount(null); Counter<L> probs = platt.logProbabilityOf(new RVFDatum<>(scores)); //System.out.println(scores+" "+probs); return probs; } }