package edu.stanford.nlp.classify;

import edu.stanford.nlp.optimization.GoldenSectionLineSearch;
import edu.stanford.nlp.stats.*;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.LineSearcher;

import java.io.*;
import java.text.NumberFormat;
import java.util.*;
import java.util.function.Function;
import java.util.regex.Pattern;

import edu.stanford.nlp.util.logging.Redwood;

/**
 * This class is meant for training SVMs ({@link SVMLightClassifier}s). It actually calls SVM Light, or
 * SVM Struct for multiclass SVMs, or SVM perf if that option is enabled, on the command line, reads in
 * the produced model file, and creates a LinearClassifier. A Platt model is also trained
 * (unless otherwise specified) on top of the SVM so that probabilities can
 * be produced. For a multiclass classifier, you have to set C using setC, otherwise the code will not run (by sonalg).
 *
 * @author Jenny Finkel
 * @author Aria Haghighi
 * @author Sarah Spikes (sdspikes@cs.stanford.edu) (templatization)
 */
public class SVMLightClassifierFactory<L, F> implements ClassifierFactory<L, F, SVMLightClassifier<L, F>> { //extends AbstractLinearClassifierFactory

  private static final long serialVersionUID = 1L;

  /**
   * C can be tuned using a held-out set or cross-validation.
   * For a binary SVM, if C=0, svmlight uses the default of 1/(avg x*x).
   */
  protected double C = -1.0;
  private boolean useSigmoid = false;
  protected boolean verbose = true;
  private String svmLightLearn = "/u/nlp/packages/svm_light/svm_learn";
  private String svmStructLearn = "/u/nlp/packages/svm_multiclass/svm_multiclass_learn";
  private String svmPerfLearn = "/u/nlp/packages/svm_perf/svm_perf_learn";
  private String svmLightClassify = "/u/nlp/packages/svm_light/svm_classify";
  private String svmStructClassify = "/u/nlp/packages/svm_multiclass/svm_multiclass_classify";
  private String svmPerfClassify = "/u/nlp/packages/svm_perf/svm_perf_classify";
  private boolean useAlphaFile = false;
  protected File alphaFile;
  private boolean deleteTempFilesOnExit = true;
  private int svmLightVerbosity = 0; // not verbose
  private boolean doEval = false;
  private boolean useSVMPerf = false;

  static final Redwood.RedwoodChannels logger = Redwood.channels(SVMLightClassifierFactory.class);

  /**
   * @param svmLightLearn is the full pathname of the training program of svmLight, with default value "/u/nlp/packages/svm_light/svm_learn"
   * @param svmStructLearn is the full pathname of the training program of svmMultiClass, with default value "/u/nlp/packages/svm_multiclass/svm_multiclass_learn"
   * @param svmPerfLearn is the full pathname of the training program of svmPerf, with default value "/u/nlp/packages/svm_perf/svm_perf_learn"
   */
  public SVMLightClassifierFactory(String svmLightLearn, String svmStructLearn, String svmPerfLearn) {
    this.svmLightLearn = svmLightLearn;
    this.svmStructLearn = svmStructLearn;
    this.svmPerfLearn = svmPerfLearn;
  }

  public SVMLightClassifierFactory() {
  }

  public SVMLightClassifierFactory(boolean useSVMPerf) {
    this.useSVMPerf = useSVMPerf;
  }

  /**
   * Set the C parameter (for the slack variables) for training the SVM.
   */
  public void setC(double C) {
    this.C = C;
  }

  /**
   * Get the C parameter (for the slack variables) for training the SVM.
   */
  public double getC() {
    return C;
  }
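  /*
   * A minimal usage sketch (illustrative only; it assumes the svm_learn binary paths
   * above are valid on the local machine and that a GeneralDataset<String, String>
   * named `dataset` has been built elsewhere):
   *
   *   SVMLightClassifierFactory<String, String> factory = new SVMLightClassifierFactory<>();
   *   factory.setC(1.0);            // required for the multiclass (SVM Struct) case
   *   SVMLightClassifier<String, String> classifier = factory.trainClassifier(dataset);
   */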
  /**
   * Specify whether or not to train an overlying Platt (sigmoid)
   * model for producing meaningful probabilities.
   */
  public void setUseSigmoid(boolean useSigmoid) {
    this.useSigmoid = useSigmoid;
  }

  /**
   * Get whether or not to train an overlying Platt (sigmoid)
   * model for producing meaningful probabilities.
   */
  public boolean getUseSigma() {
    return useSigmoid;
  }

  public boolean getDeleteTempFilesOnExitFlag() {
    return deleteTempFilesOnExit;
  }

  public void setDeleteTempFilesOnExitFlag(boolean deleteTempFilesOnExit) {
    this.deleteTempFilesOnExit = deleteTempFilesOnExit;
  }

  /**
   * Reads in a model file in svm_light format. It needs to know whether it's multiclass or not
   * because that affects the number of header lines. Maybe there is another way to tell, and we
   * can remove this flag?
   */
  private static Pair<Double, ClassicCounter<Integer>> readModel(File modelFile, boolean multiclass) {
    int modelLineCount = 0;
    // try-with-resources so the reader is closed even if parsing fails partway through
    try (BufferedReader in = new BufferedReader(new FileReader(modelFile))) {
      int numLinesToSkip = multiclass ? 13 : 10;
      String stopToken = "#";
      for (int i = 0; i < numLinesToSkip; i++) {
        in.readLine();
        modelLineCount++;
      }

      List<Pair<Double, ClassicCounter<Integer>>> supportVectors = new ArrayList<>();
      // Read the threshold
      String thresholdLine = in.readLine();
      modelLineCount++;
      String[] pieces = thresholdLine.split("\\s+");
      double threshold = Double.parseDouble(pieces[0]);

      // Read the support vectors
      while (in.ready()) {
        String svLine = in.readLine();
        modelLineCount++;
        pieces = svLine.split("\\s+");
        // The first element is alpha_i * y_i
        double alpha = Double.parseDouble(pieces[0]);
        ClassicCounter<Integer> supportVector = new ClassicCounter<>();
        for (int i = 1; i < pieces.length; ++i) {
          String piece = pieces[i];
          if (piece.equals(stopToken)) break;
          // Each piece is in featureIndex:value form
          String[] indexNum = piece.split(":");
          String featureIndex = indexNum[0];
          // mihai: we may see "qid" as indexNum[0]; just skip this piece. This is the block id,
          // useful only for reranking, which we don't do here.
          if ( ! featureIndex.equals("qid")) {
            double count = Double.parseDouble(indexNum[1]);
            supportVector.incrementCount(Integer.valueOf(featureIndex), count);
          }
        }
        supportVectors.add(new Pair<>(alpha, supportVector));
      }

      return new Pair<>(threshold, getWeights(supportVectors));
    } catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException("Error reading SVM model (line " + modelLineCount + " in file " + modelFile.getAbsolutePath() + ")");
    }
  }

  /**
   * Takes all the support vectors, and their corresponding alphas, and computes a weight
   * vector that can be used in a vanilla LinearClassifier. This only works because
   * we are using a linear kernel. The Counter is over the feature indices (+1 because
   * svm_light is 1-indexed), not features.
   */
  private static ClassicCounter<Integer> getWeights(List<Pair<Double, ClassicCounter<Integer>>> supportVectors) {
    ClassicCounter<Integer> weights = new ClassicCounter<>();
    for (Pair<Double, ClassicCounter<Integer>> sv : supportVectors) {
      ClassicCounter<Integer> c = new ClassicCounter<>(sv.second());
      Counters.multiplyInPlace(c, sv.first());
      Counters.addInPlace(weights, c);
    }
    return weights;
  }
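  /*
   * For a linear kernel the decision function w.x - b collapses into a single weight
   * vector: w = sum_i (alpha_i * y_i) * x_i. A tiny worked example with hypothetical
   * numbers: given support vectors x1 = {1: 2.0, 3: 1.0} with alpha*y = 0.5 and
   * x2 = {1: 1.0} with alpha*y = -0.25, getWeights returns {1: 0.75, 3: 0.5}.
   */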
  /**
   * Converts the weight Counter from the indexed svm_light format to a format
   * we can use in our LinearClassifier.
   */
  private ClassicCounter<Pair<F, L>> convertWeights(ClassicCounter<Integer> weights, Index<F> featureIndex, Index<L> labelIndex, boolean multiclass) {
    return multiclass ? convertSVMStructWeights(weights, featureIndex, labelIndex)
                      : convertSVMLightWeights(weights, featureIndex, labelIndex);
  }

  /**
   * Converts the svm_light weight Counter (which uses feature indices) into a weight Counter
   * using the actual features and labels. Because this is svm_light, and not svm_struct, the
   * weights for the +1 class (which correspond to labelIndex.get(0)) and the -1 class
   * (which correspond to labelIndex.get(1)) are just the negation of one another.
   */
  private ClassicCounter<Pair<F, L>> convertSVMLightWeights(ClassicCounter<Integer> weights, Index<F> featureIndex, Index<L> labelIndex) {
    ClassicCounter<Pair<F, L>> newWeights = new ClassicCounter<>();
    for (int i : weights.keySet()) {
      F f = featureIndex.get(i - 1);
      double w = weights.getCount(i);
      // the first entry in the labelIndex is the +1 class and the second is the -1 class
      newWeights.incrementCount(new Pair<>(f, labelIndex.get(0)), w);
      newWeights.incrementCount(new Pair<>(f, labelIndex.get(1)), -w);
    }
    return newWeights;
  }

  /**
   * Converts the svm_struct weight Counter (in which the weight for a feature/label pair
   * corresponds to ((labelIndex * numFeatures) + (featureIndex + 1))) into a weight Counter
   * using the actual features and labels.
   */
  private ClassicCounter<Pair<F, L>> convertSVMStructWeights(ClassicCounter<Integer> weights, Index<F> featureIndex, Index<L> labelIndex) {
    // int numLabels = labelIndex.size();
    int numFeatures = featureIndex.size();
    ClassicCounter<Pair<F, L>> newWeights = new ClassicCounter<>();
    for (int i : weights.keySet()) {
      L l = labelIndex.get((i - 1) / numFeatures); // integer division on purpose
      F f = featureIndex.get((i - 1) % numFeatures);
      double w = weights.getCount(i);
      newWeights.incrementCount(new Pair<>(f, l), w);
    }
    return newWeights;
  }
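  /*
   * A worked example of the svm_struct index mapping above (hypothetical sizes):
   * with numFeatures = 4, a raw weight index i = 6 decodes to label (6-1)/4 = 1 and
   * feature (6-1)%4 = 1, i.e. the pair (featureIndex.get(1), labelIndex.get(1)).
   */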
  /**
   * Builds a sigmoid model to turn the classifier outputs into probabilities.
   */
  private LinearClassifier<L, L> fitSigmoid(SVMLightClassifier<L, F> classifier, GeneralDataset<L, F> dataset) {
    RVFDataset<L, L> plattDataset = new RVFDataset<>();
    for (int i = 0; i < dataset.size(); i++) {
      RVFDatum<L, F> d = dataset.getRVFDatum(i);
      Counter<L> scores = classifier.scoresOf((Datum<L, F>) d);
      scores.incrementCount(null); // a constant feature keyed on null, acting as a bias term
      plattDataset.add(new RVFDatum<>(scores, d.label()));
    }
    LinearClassifierFactory<L, L> factory = new LinearClassifierFactory<>();
    factory.setPrior(new LogPrior(LogPrior.LogPriorType.NULL));
    return factory.trainClassifier(plattDataset);
  }

  /**
   * This method will cross validate on the given data and number of folds
   * to find the optimal C. The scorer is how you determine what to
   * optimize for (F-score, accuracy, etc). The C is then saved, so that
   * if you train a classifier after calling this method, that C will be used.
   */
  public void crossValidateSetC(GeneralDataset<L, F> dataset, int numFolds, final Scorer<L> scorer, LineSearcher minimizer) {
    System.out.println("in Cross Validate");

    useAlphaFile = true;
    boolean oldUseSigmoid = useSigmoid;
    useSigmoid = false;

    final CrossValidator<L, F> crossValidator = new CrossValidator<>(dataset, numFolds);
    final Function<Triple<GeneralDataset<L, F>, GeneralDataset<L, F>, CrossValidator.SavedState>, Double> score =
        fold -> {
          GeneralDataset<L, F> trainSet = fold.first();
          GeneralDataset<L, F> devSet = fold.second();
          alphaFile = (File) fold.third().state;
          //train(trainSet,true,true);
          SVMLightClassifier<L, F> classifier = trainClassifierBasic(trainSet);
          fold.third().state = alphaFile;
          return scorer.score(classifier, devSet);
        };

    Function<Double, Double> negativeScorer =
        cToTry -> {
          C = cToTry;
          if (verbose) { System.out.print("C = " + cToTry + " "); }
          Double averageScore = crossValidator.computeAverage(score);
          if (verbose) { System.out.println(" -> average Score: " + averageScore); }
          return -averageScore;
        };

    C = minimizer.minimize(negativeScorer);

    useAlphaFile = false;
    useSigmoid = oldUseSigmoid;
  }

  public void heldOutSetC(GeneralDataset<L, F> train, double percentHeldOut, final Scorer<L> scorer, LineSearcher minimizer) {
    Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> data = train.split(percentHeldOut);
    heldOutSetC(data.first(), data.second(), scorer, minimizer);
  }

  /**
   * This method will train on the given training set and evaluate on the given dev set
   * to find the optimal C. The scorer is how you determine what to
   * optimize for (F-score, accuracy, etc). The C is then saved, so that
   * if you train a classifier after calling this method, that C will be used.
   */
  public void heldOutSetC(final GeneralDataset<L, F> trainSet, final GeneralDataset<L, F> devSet, final Scorer<L> scorer, LineSearcher minimizer) {
    useAlphaFile = true;
    boolean oldUseSigmoid = useSigmoid;
    useSigmoid = false;

    Function<Double, Double> negativeScorer =
        cToTry -> {
          C = cToTry;
          SVMLightClassifier<L, F> classifier = trainClassifierBasic(trainSet);
          double score = scorer.score(classifier, devSet);
          return -score;
        };

    C = minimizer.minimize(negativeScorer);

    useAlphaFile = false;
    useSigmoid = oldUseSigmoid;
  }

  private boolean tuneHeldOut = false;
  private boolean tuneCV = false;
  private Scorer<L> scorer = new MultiClassAccuracyStats<>();
  private LineSearcher tuneMinimizer = new GoldenSectionLineSearch(true);
  private int folds;
  private double heldOutPercent;

  public double getHeldOutPercent() {
    return heldOutPercent;
  }

  public void setHeldOutPercent(double heldOutPercent) {
    this.heldOutPercent = heldOutPercent;
  }

  public int getFolds() {
    return folds;
  }

  public void setFolds(int folds) {
    this.folds = folds;
  }

  public LineSearcher getTuneMinimizer() {
    return tuneMinimizer;
  }

  public void setTuneMinimizer(LineSearcher minimizer) {
    this.tuneMinimizer = minimizer;
  }

  public Scorer<L> getScorer() {
    return scorer;
  }

  public void setScorer(Scorer<L> scorer) {
    this.scorer = scorer;
  }

  public boolean getTuneCV() {
    return tuneCV;
  }

  public void setTuneCV(boolean tuneCV) {
    this.tuneCV = tuneCV;
  }

  public boolean getTuneHeldOut() {
    return tuneHeldOut;
  }

  public void setTuneHeldOut(boolean tuneHeldOut) {
    this.tuneHeldOut = tuneHeldOut;
  }

  public int getSvmLightVerbosity() {
    return svmLightVerbosity;
  }

  public void setSvmLightVerbosity(int svmLightVerbosity) {
    this.svmLightVerbosity = svmLightVerbosity;
  }

  public SVMLightClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
    if (tuneHeldOut) {
      heldOutSetC(dataset, heldOutPercent, scorer, tuneMinimizer);
    } else if (tuneCV) {
      crossValidateSetC(dataset, folds, scorer, tuneMinimizer);
    }
    return trainClassifierBasic(dataset);
  }
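  /*
   * A tuning sketch (illustrative; assumes a Scorer appropriate for the label type):
   *
   *   factory.setTuneCV(true);    // or factory.setTuneHeldOut(true) with setHeldOutPercent(0.3)
   *   factory.setFolds(5);
   *   factory.setScorer(new MultiClassAccuracyStats<>());
   *   SVMLightClassifier<String, String> classifier = factory.trainClassifier(dataset);
   *
   * trainClassifier then first searches for C with the tuneMinimizer (a
   * GoldenSectionLineSearch by default) before training the final model.
   */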
  Pattern whitespacePattern = Pattern.compile("\\s+");

  public SVMLightClassifier<L, F> trainClassifierBasic(GeneralDataset<L, F> dataset) {
    Index<L> labelIndex = dataset.labelIndex();
    Index<F> featureIndex = dataset.featureIndex;
    boolean multiclass = (dataset.numClasses() > 2);
    try {
      // this is the file that the model will be saved to
      File modelFile = File.createTempFile("svm-", ".model");
      if (deleteTempFilesOnExit) {
        modelFile.deleteOnExit();
      }

      // this is the file that the svm_light formatted dataset will be printed to
      File dataFile = File.createTempFile("svm-", ".data");
      if (deleteTempFilesOnExit) {
        dataFile.deleteOnExit();
      }

      // print the dataset
      PrintWriter pw = new PrintWriter(new FileWriter(dataFile));
      dataset.printSVMLightFormat(pw);
      pw.close();

      // -v 0 makes it not verbose
      // -m 400 gives it a larger cache, for faster training
      String cmd = (multiclass ? svmStructLearn : (useSVMPerf ? svmPerfLearn : svmLightLearn)) + " -v " + svmLightVerbosity + " -m 400 ";

      // set the value of C if we have one specified
      if (C > 0.0) cmd = cmd + " -c " + C + " "; // C value
      else if (useSVMPerf) cmd = cmd + " -c " + 0.01 + " "; // SVM perf requires this parameter to be specified

      // alpha file
      if (useAlphaFile) {
        File newAlphaFile = File.createTempFile("svm-", ".alphas");
        if (deleteTempFilesOnExit) {
          newAlphaFile.deleteOnExit();
        }
        cmd = cmd + " -a " + newAlphaFile.getAbsolutePath();
        if (alphaFile != null) {
          cmd = cmd + " -y " + alphaFile.getAbsolutePath();
        }
        alphaFile = newAlphaFile;
      }

      // data file and model file
      cmd = cmd + " " + dataFile.getAbsolutePath() + " " + modelFile.getAbsolutePath();

      if (verbose) logger.info("<< " + cmd + " >>");
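      // At this point, for the binary case, cmd resembles (temp-file names are illustrative):
      //   /u/nlp/packages/svm_light/svm_learn -v 0 -m 400  -c 1.0  /tmp/svm-123.data /tmp/svm-456.model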
      /*
      Process p = Runtime.getRuntime().exec(cmd);
      p.waitFor();
      if (p.exitValue() != 0)
        throw new RuntimeException("Error Training SVM Light exit value: " + p.exitValue());
      p.destroy();
      */
      SystemUtils.run(new ProcessBuilder(whitespacePattern.split(cmd)),
          new PrintWriter(System.err), new PrintWriter(System.err));

      if (doEval) {
        File predictFile = File.createTempFile("svm-", ".pred");
        if (deleteTempFilesOnExit) {
          predictFile.deleteOnExit();
        }
        String evalCmd = (multiclass ? svmStructClassify : (useSVMPerf ? svmPerfClassify : svmLightClassify)) +
            " " + dataFile.getAbsolutePath() + " " + modelFile.getAbsolutePath() + " " + predictFile.getAbsolutePath();
        if (verbose) logger.info("<< " + evalCmd + " >>");
        SystemUtils.run(new ProcessBuilder(whitespacePattern.split(evalCmd)),
            new PrintWriter(System.err), new PrintWriter(System.err));
      }

      // read in the model file
      Pair<Double, ClassicCounter<Integer>> weightsAndThresh = readModel(modelFile, multiclass);
      double threshold = weightsAndThresh.first();
      ClassicCounter<Pair<F, L>> weights = convertWeights(weightsAndThresh.second(), featureIndex, labelIndex, multiclass);
      ClassicCounter<L> thresholds = new ClassicCounter<>();
      if (!multiclass) {
        thresholds.setCount(labelIndex.get(0), -threshold);
        thresholds.setCount(labelIndex.get(1), threshold);
      }
      SVMLightClassifier<L, F> classifier = new SVMLightClassifier<>(weights, thresholds);

      if (doEval) {
        File predictFile = File.createTempFile("svm-", ".pred2");
        if (deleteTempFilesOnExit) {
          predictFile.deleteOnExit();
        }
        PrintWriter pw2 = new PrintWriter(predictFile);
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(5);
        for (Datum<L, F> datum : dataset) {
          Counter<L> scores = classifier.scoresOf(datum);
          pw2.println(Counters.toString(scores, nf));
        }
        pw2.close();
      }

      if (useSigmoid) {
        if (verbose) System.out.print("fitting sigmoid...");
        classifier.setPlatt(fitSigmoid(classifier, dataset));
        if (verbose) System.out.println("done");
      }

      return classifier;
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

}
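/*
 * End-to-end scoring sketch (illustrative; `dataset` and `datum` are assumed to have
 * been built elsewhere, e.g. with edu.stanford.nlp.classify.Dataset):
 *
 *   factory.setUseSigmoid(true); // attaches a Platt model via setPlatt(fitSigmoid(...))
 *   SVMLightClassifier<String, String> classifier = factory.trainClassifier(dataset);
 *   Counter<String> scores = classifier.scoresOf(datum); // per-label margins
 */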