package edu.stanford.nlp.classify; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Map; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counters; import edu.stanford.nlp.util.ErasureUtils; import edu.stanford.nlp.util.Index; /** * A multinomial logistic regression classifier. Please see FlippingProbsLogisticClassifierFactory * or ShiftParamsLogisticClassifierFactory for example use cases. * * This is classic multinomial logistic regression where you have one reference class (the last one) and * (numClasses - 1) times numFeatures weights, unlike the maxent/softmax regression we more normally use. * * @author jtibs */ public class MultinomialLogisticClassifier<L, F> implements ProbabilisticClassifier<L, F>, RVFClassifier<L, F> { private static final long serialVersionUID = 1L; private final double[][] weights; private final Index<F> featureIndex; private final Index<L> labelIndex; /** * @param weights A (numClasses - 1) by numFeatures matrix that holds the weight array for each * class. Note that only (numClasses - 1) rows are needed, as the probability for last class is * uniquely determined by the others. */ public MultinomialLogisticClassifier(double[][] weights, Index<F> featureIndex, Index<L> labelIndex) { this.featureIndex = featureIndex; this.labelIndex = labelIndex; this.weights = weights; } @Override public Collection<L> labels() { return labelIndex.objectsList(); } @Override public L classOf(Datum<L, F> example) { return Counters.argmax(scoresOf(example)); } @Override public Counter<L> scoresOf(Datum<L, F> example) { return logProbabilityOf(example); } @Override public L classOf(RVFDatum<L, F> example) { return classOf((Datum<L, F>)example); } @Override public Counter<L> scoresOf(RVFDatum<L, F> example) { return scoresOf((Datum<L, F>)example); } @Override public Counter<L> probabilityOf(Datum<L, F> example) { // calculate the feature indices and feature values int[] featureIndices = LogisticUtils.indicesOf(example.asFeatures(), featureIndex); double[] featureValues; if (example instanceof RVFDatum<?, ?>) { Collection<Double> featureValuesCollection = ((RVFDatum<?, ?>) example).asFeaturesCounter().values(); featureValues = LogisticUtils.convertToArray(featureValuesCollection); } else { featureValues = new double[example.asFeatures().size()]; Arrays.fill(featureValues, 1.0); } // calculate probability of each class Counter<L> result = new ClassicCounter<>(); int numClasses = labelIndex.size(); double[] sigmoids = LogisticUtils.calculateSigmoids(weights, featureIndices, featureValues); for (int c = 0; c < numClasses; c++) { L label = labelIndex.get(c); result.incrementCount(label, sigmoids[c]); } return result; } @Override public Counter<L> logProbabilityOf(Datum<L, F> example) { Counter<L> result = probabilityOf(example); Counters.logInPlace(result); return result; } private static <LL,FF> MultinomialLogisticClassifier<LL,FF> load(String path) throws IOException, ClassNotFoundException { System.err.print("Loading classifier from " + path + "... "); ObjectInputStream in = new ObjectInputStream(new FileInputStream(path)); double[][] myWeights = ErasureUtils.uncheckedCast(in.readObject()); Index<FF> myFeatureIndex = ErasureUtils.uncheckedCast(in.readObject()); Index<LL> myLabelIndex = ErasureUtils.uncheckedCast(in.readObject()); in.close(); System.err.println("done."); return new MultinomialLogisticClassifier<>(myWeights, myFeatureIndex, myLabelIndex); } private void save(String path) throws IOException { System.out.print("Saving classifier to " + path + "... "); // make sure the directory specified by path exists int lastSlash = path.lastIndexOf(File.separator); if (lastSlash > 0) { File dir = new File(path.substring(0, lastSlash)); if (! dir.exists()) dir.mkdirs(); } ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(path)); out.writeObject(weights); out.writeObject(featureIndex); out.writeObject(labelIndex); out.close(); System.out.println("done."); } public Map<L, Counter<F>> weightsAsGenericCounter() { Map<L, Counter<F>> allweights = new HashMap<>(); for(int i = 0; i < weights.length; i++){ Counter<F> c = new ClassicCounter<>(); L label = labelIndex.get(i); double[] w = weights[i]; for (F f : featureIndex) { int indexf = featureIndex.indexOf(f); if(w[indexf] != 0.0) c.setCount(f, w[indexf]); } allweights.put(label, c); } return allweights; } }