package edu.stanford.nlp.coref.statistical; import java.io.PrintWriter; import java.util.Map; import java.util.SortedMap; import java.util.TreeMap; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counters; import edu.stanford.nlp.util.Timing; import edu.stanford.nlp.util.logging.Redwood; /** * A simple linear classifier trained by SGD with support for several different loss functions * and learning rate schedules. * @author Kevin Clark */ public class SimpleLinearClassifier { private static Redwood.RedwoodChannels log = Redwood.channels(SimpleLinearClassifier.class); private final Loss defaultLoss; private final LearningRateSchedule learningRateSchedule; private final double regularizationStrength; private final Counter<String> weights; private final Counter<String> accessTimes; private int examplesSeen; public SimpleLinearClassifier(Loss loss, LearningRateSchedule learningRateSchedule, double regularizationStrength) { this(loss, learningRateSchedule, regularizationStrength, null); } public SimpleLinearClassifier(Loss loss,LearningRateSchedule learningRateSchedule, double regularizationStrength, String modelFile) { if (modelFile != null) { try { if (modelFile.endsWith(".tab.gz")) { Timing.startDoing("Reading " + modelFile); this.weights = Counters.deserializeStringCounter(modelFile); Timing.endDoing("Reading " + modelFile); } else { this.weights = IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem( log, "Loading coref model", modelFile); } } catch (Exception e) { throw new RuntimeException("Error leading weights from " + modelFile, e); } } else { this.weights = new ClassicCounter<>(); } this.defaultLoss = loss; this.regularizationStrength = regularizationStrength; this.learningRateSchedule = learningRateSchedule; accessTimes = new ClassicCounter<>(); examplesSeen = 0; } public void learn(Counter<String> features, double label, double weight) { learn(features, label, weight, defaultLoss); } public void learn(Counter<String> features, double label, double weight, Loss loss) { examplesSeen++; double dloss = loss.derivative(label, weightFeatureProduct(features)); for (Map.Entry<String, Double> feature : features.entrySet()) { double dfeature = weight * (-dloss * feature.getValue()); if (dfeature != 0) { String featureName = feature.getKey(); learningRateSchedule.update(featureName, dfeature); double lr = learningRateSchedule.getLearningRate(featureName); double w = weights.getCount(featureName); double dreg = weight * regularizationStrength * (examplesSeen - accessTimes.getCount(featureName)); double afterReg = (w - Math.signum(w) * dreg * lr); weights.setCount(featureName, (Math.signum(afterReg) != Math.signum(w) ? 0 : afterReg) + dfeature * lr); accessTimes.setCount(featureName, examplesSeen); } } } public double label(Counter<String> features) { return defaultLoss.predict(weightFeatureProduct(features)); } public double weightFeatureProduct(Counter<String> features) { double product = 0; for (Map.Entry<String, Double> feature : features.entrySet()) { product += feature.getValue() * weights.getCount(feature.getKey()); } return product; } public void setWeight(String featureName, double weight) { weights.setCount(featureName, weight); } public SortedMap<String, Double> getWeightVector() { SortedMap<String, Double> m = new TreeMap<>((f1, f2) -> { double weightDifference = Math.abs(weights.getCount(f2)) - Math.abs(weights.getCount(f1)); return weightDifference == 0 ? f1.compareTo(f2) : (int) Math.signum(weightDifference); }); weights.entrySet().stream().forEach(e -> m.put(e.getKey(), e.getValue())); return m; } public void printWeightVector() { printWeightVector(null); } public void printWeightVector(PrintWriter writer) { SortedMap<String, Double> sortedWeights = getWeightVector(); for (Map.Entry<String, Double> e : sortedWeights.entrySet()) { if (writer == null) { Redwood.log("scoref.train", e.getKey() + " => " + e.getValue()); } else { writer.println(e.getKey() + " => " + e.getValue()); } } } public void writeWeights(String fname) throws Exception { IOUtils.writeObjectToFile(weights, fname); } // ---------- LOSS FUNCTIONS ---------- public static interface Loss { public double predict(double product); public double derivative(double label, double product); } public static Loss log() { return new Loss() { @Override public double predict(double product) { return (1 - (1 / (1 + Math.exp(product)))); } @Override public double derivative(double label, double product) { return -label / (1 + Math.exp(label * product)); } @Override public String toString() { return "log"; } }; } public static Loss quadraticallySmoothedSVM(final double gamma) { return new Loss() { @Override public double predict(double product) { return product; }; @Override public double derivative(double label, double product) { double mistake = label * product; return mistake >= 1 ? 0 : (mistake >= 1 - gamma ? (mistake - 1) * label / gamma : -label); } @Override public String toString() { return String.format("quadraticallySmoothed(%s)", gamma); } }; } public static Loss hinge() { return quadraticallySmoothedSVM(0); } public static Loss maxMargin(final double h) { return new Loss() { @Override public double predict(double product) { throw new UnsupportedOperationException("Predict not implemented for max margin"); } @Override public double derivative(double label, double product) { return product < -h ? 0 : 1; } @Override public String toString() { return String.format("max-margin(%s)", h); } }; } public static Loss risk() { return new Loss() { @Override public double predict(double product) { return 1 / (1 + Math.exp(product)); } @Override public double derivative(double label, double product) { return -Math.exp(product) / Math.pow(1 + Math.exp(product), 2); } @Override public String toString() { return "risk"; } }; } // ---------- LEARNING RATE SCHEDULES ---------- public static interface LearningRateSchedule { public void update(String feature, double gradient); public double getLearningRate(String feature); } private abstract static class CountBasedLearningRate implements LearningRateSchedule { private final Counter<String> counter; public CountBasedLearningRate() { counter = new ClassicCounter<>(); } @Override public void update(String feature, double gradient) { counter.incrementCount(feature, getCounterIncrement(gradient)); } @Override public double getLearningRate(String feature) { return getLearningRate(counter.getCount(feature)); } public abstract double getCounterIncrement(double gradient); public abstract double getLearningRate(double count); } public static LearningRateSchedule constant(final double eta) { return new LearningRateSchedule() { @Override public double getLearningRate(String feature) { return eta; } @Override public void update(String feature, double gradient) { } @Override public String toString() { return String.format("constant(%s)", eta); } }; } public static LearningRateSchedule invScaling(final double eta, final double p) { return new CountBasedLearningRate() { @Override public double getCounterIncrement(double gradient) { return 1.0; } @Override public double getLearningRate(double count) { return eta / Math.pow(1 + count, p); } @Override public String toString() { return String.format("invScaling(%s, %s)", eta, p); } }; } public static LearningRateSchedule adaGrad(final double eta, final double tau) { return new CountBasedLearningRate() { @Override public double getCounterIncrement(double gradient) { return gradient * gradient; } @Override public double getLearningRate(double count) { return eta / (tau + Math.sqrt(count)); } @Override public String toString() { return String.format("adaGrad(%s, %s)", eta, tau); } }; } }