/*
 * Copyright 2014, Language Technologies Institute, Carnegie Mellon University
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package cmu.arktweetnlp;

import java.io.IOException;
import java.util.ArrayList;

import cmu.arktweetnlp.impl.Model;
import cmu.arktweetnlp.impl.ModelSentence;
import cmu.arktweetnlp.impl.OWLQN;
import cmu.arktweetnlp.impl.Sentence;
import cmu.arktweetnlp.impl.OWLQN.WeightsPrinter;
import cmu.arktweetnlp.impl.features.FeatureExtractor;
import cmu.arktweetnlp.io.CoNLLReader;
import cmu.arktweetnlp.util.Util;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.DiffFunction;

/**
 * Command-line trainer: reads labeled sentences, extracts features, and fits
 * model coefficients with OWL-QN, minimizing the negative log-likelihood plus
 * an L2 penalty (computed here) and an L1 penalty (handed to OWLQN).
 * Alternatively, with {@code --dump-feat}, only prints extracted features.
 */
public class Train {

    /** L2 penalty coefficient; the objective adds {@code 0.5 * l2penalty * sum(w_j^2)}. */
    public double l2penalty = 2;
    /** L1 penalty coefficient; passed to {@code OWLQN.minimize}, which applies it itself. */
    public double l1penalty = 0.25;
    /** Tolerance argument passed to {@code OWLQN.minimize}. */
    public double tol = 1e-7;
    /** Maximum number of optimizer iterations ({@code OWLQN.setMaxIters}). */
    public int maxIter = 500;

    /** If non-null, a text model whose coefficients initialize training (warm start). */
    public String modelLoadFilename = null;
    /** Path of the training examples, read with {@code CoNLLReader.readFile}. */
    public String examplesFilename = null;
    /** Path the trained model is written to ({@code Model.saveModelAsText}). */
    public String modelSaveFilename = null;
    /** When true, {@link #main} dumps extracted features instead of training. */
    public boolean dumpFeatures = false;

    // Data structures
    /** Labeled training sentences as read from disk. */
    private ArrayList<Sentence> lSentences;
    /** Feature-extracted counterparts of {@link #lSentences}, in the same order. */
    private ArrayList<ModelSentence> mSentences;
    /** Total token count over {@link #lSentences}; used to normalize the progress log-likelihood. */
    private int numTokens = 0;
    private Model model;

    Train() {
        lSentences = new ArrayList<Sentence>();
        mSentences = new ArrayList<ModelSentence>();
        model = new Model();
    }

    /**
     * Reads the examples, builds the label vocabulary and feature set, then
     * re-runs feature extraction in dump mode.
     *
     * @throws IOException if the examples file cannot be read
     */
    public void doFeatureDumping() throws IOException {
        readTrainingSentences(examplesFilename);
        constructLabelVocab();
        extractFeatures();
        dumpFeatures();
    }

    /**
     * Full training pipeline: read data, build vocabularies, extract features,
     * optionally warm-start from {@link #modelLoadFilename}, optimize, and save
     * the model to {@link #modelSaveFilename}.
     *
     * @throws IOException if reading the data/warm-start model or saving fails
     */
    public void doTraining() throws IOException {
        readTrainingSentences(examplesFilename);
        constructLabelVocab();
        extractFeatures();
        model.lockdownAfterFeatureExtraction();
        if (modelLoadFilename != null) {
            readWarmStartModel();
        }
        optimizationLoop();
        model.saveModelAsText(modelSaveFilename);
    }

    /**
     * Loads training sentences from {@code filename}, replacing any previously
     * loaded ones, and recomputes {@link #numTokens}.
     *
     * @param filename path passed to {@code CoNLLReader.readFile}
     * @throws IOException if the file cannot be read
     */
    public void readTrainingSentences(String filename) throws IOException {
        lSentences = CoNLLReader.readFile(filename);
        // The sentence list is replaced, so the token count must be too;
        // accumulating onto a stale count would double-count on a repeat call.
        numTokens = 0;
        for (Sentence sent : lSentences) {
            numTokens += sent.T();
        }
    }

    /**
     * Passes every training label through {@code model.labelVocab.num}, then
     * locks the label vocabulary and records its size in {@code model.numLabels}.
     */
    public void constructLabelVocab() {
        for (Sentence s : lSentences) {
            for (String l : s.labels) {
                model.labelVocab.num(l);
            }
        }
        model.labelVocab.lock();
        model.numLabels = model.labelVocab.size();
    }

    /**
     * Runs a {@code FeatureExtractor} with {@code dumpMode = true} over every
     * training sentence. The resulting ModelSentences are discarded; the output
     * is whatever the extractor emits in dump mode (presumably printed features
     * — confirm in FeatureExtractor).
     *
     * @throws IOException if the feature extractor fails to initialize
     */
    public void dumpFeatures() throws IOException {
        FeatureExtractor fe = new FeatureExtractor(model, true);
        fe.dumpMode = true;
        for (Sentence lSent : lSentences) {
            ModelSentence mSent = new ModelSentence(lSent.T());
            fe.computeFeatures(lSent, mSent);
        }
    }

    /**
     * Computes features for every training sentence and stores the results in
     * {@link #mSentences}, parallel to {@link #lSentences}.
     *
     * @throws IOException if the feature extractor fails to initialize
     */
    public void extractFeatures() throws IOException {
        System.out.println("Extracting features");
        FeatureExtractor fe = new FeatureExtractor(model, true);
        for (Sentence lSent : lSentences) {
            ModelSentence mSent = new ModelSentence(lSent.T());
            fe.computeFeatures(lSent, mSent);
            mSentences.add(mSent);
        }
    }

    /**
     * Loads {@link #modelLoadFilename} and copies its coefficients into the
     * current model for features both models share. The current feature
     * vocabulary must already be locked.
     *
     * @throws IOException if the warm-start model cannot be read
     */
    public void readWarmStartModel() throws IOException {
        assert model.featureVocab.isLocked();
        Model warmModel = Model.loadModelFromText(modelLoadFilename);
        Model.copyCoefsForIntersectingFeatures(warmModel, model);
    }

    /**
     * Minimizes the regularized negative log-likelihood with OWL-QN, starting
     * from the model's current coefficients, and writes the optimum back into
     * the model.
     */
    public void optimizationLoop() {
        OWLQN minimizer = new OWLQN();
        minimizer.setMaxIters(maxIter);
        minimizer.setQuiet(false);
        minimizer.setWeightsPrinting(new MyWeightsPrinter());

        double[] initialWeights = model.convertCoefsToFlat();

        // NOTE(review): the trailing 5 is presumably the L-BFGS history size;
        // confirm against OWLQN.minimize's signature.
        double[] finalWeights = minimizer.minimize(
                new GradientCalculator(),
                initialWeights, l1penalty, tol, 5);

        model.setCoefsFromFlat(finalWeights);
    }

    /**
     * Objective for OWLQN: negative training log-likelihood plus the L2 penalty.
     * The L1 penalty is not included here; it is passed to OWLQN separately.
     */
    private class GradientCalculator implements DiffFunction {

        @Override
        public int domainDimension() {
            return model.flatIDsize();
        }

        @Override
        public double valueAt(double[] flatCoefs) {
            model.setCoefsFromFlat(flatCoefs);
            double loglik = 0;
            for (ModelSentence s : mSentences) {
                loglik += model.computeLogLik(s);
            }
            return -loglik + regularizerValue(flatCoefs);
        }

        @Override
        public double[] derivativeAt(double[] flatCoefs) {
            double[] g = new double[model.flatIDsize()];
            model.setCoefsFromFlat(flatCoefs);
            // computeGradient is called per sentence on the same array, so it
            // is expected to accumulate the log-likelihood gradient into g.
            for (ModelSentence s : mSentences) {
                model.computeGradient(s, g);
            }
            // Negate: we minimize the negative log-likelihood.
            ArrayMath.multiplyInPlace(g, -1);
            addL2regularizerGradient(g, flatCoefs);
            return g;
        }
    }

    /**
     * Adds the gradient of {@code 0.5 * l2penalty * sum(w_j^2)}, i.e.
     * {@code l2penalty * w_j}, to {@code grad} in place.
     */
    private void addL2regularizerGradient(double[] grad, double[] flatCoefs) {
        assert grad.length == flatCoefs.length;
        for (int f = 0; f < flatCoefs.length; f++) {
            grad[f] += l2penalty * flatCoefs[f];
        }
    }

    /**
     * Full penalty is lambda_2 * (1/2) * sum(beta_j^2) + lambda_1 * sum|beta_j|.
     * Only the L2 term is computed here: the L1 coefficient is passed directly
     * to {@code OWLQN.minimize} in {@link #optimizationLoop}.
     *
     * @return {@code 0.5 * l2penalty * sum(flatCoefs[f]^2)}
     */
    private double regularizerValue(double[] flatCoefs) {
        double l2Term = 0;
        for (int f = 0; f < flatCoefs.length; f++) {
            l2Term += flatCoefs[f] * flatCoefs[f];
        }
        return 0.5 * l2penalty * l2Term;
    }

    /**
     * Progress callback registered with OWLQN: prints the training
     * log-likelihood averaged over all tokens.
     */
    public class MyWeightsPrinter implements WeightsPrinter {
        @Override
        public void printWeights() {
            double loglik = 0;
            for (ModelSentence s : mSentences) {
                loglik += model.computeLogLik(s);
            }
            System.out.printf("\tTokLL %.6f\t", loglik / numTokens);
        }
    }

    //////////////////////////////////////////////////////////////

    /**
     * Entry point. Parses leading {@code --option} arguments, then either dumps
     * features for {@code <ExamplesFilename>} or trains and saves a model to
     * {@code <ModelOutputFilename>}. Prints usage and exits on bad arguments.
     */
    public static void main(String[] args) throws IOException {
        Train trainer = new Train();

        if (args.length < 2 || args[0].equals("-h") || args[0].equals("--help")) {
            usage();
        }

        int i = 0;
        while (i < args.length) {
            if (!args[i].startsWith("-")) {
                break;
            } else if (args[i].equals("--warm-start")) {
                trainer.modelLoadFilename = optionValue(args, i);
                i += 2;
            } else if (args[i].equals("--max-iter")) {
                trainer.maxIter = Integer.parseInt(optionValue(args, i));
                i += 2;
            } else if (args[i].equals("--dump-feat")) {
                trainer.dumpFeatures = true;
                i += 1;
            } else if (args[i].equals("--l2")) {
                trainer.l2penalty = Double.parseDouble(optionValue(args, i));
                i += 2;
            } else if (args[i].equals("--l1")) {
                trainer.l1penalty = Double.parseDouble(optionValue(args, i));
                i += 2;
            } else {
                usage();
            }
        }

        if (trainer.dumpFeatures) {
            // Feature dumping needs only the examples file.
            if (i >= args.length) {
                usage();
            }
            trainer.examplesFilename = args[i];
            trainer.doFeatureDumping();
            System.exit(0);
        }

        if (args.length - i < 2) {
            usage();
        }
        trainer.examplesFilename = args[i];
        trainer.modelSaveFilename = args[i + 1];
        trainer.doTraining();
    }

    /**
     * Returns the value following the option at {@code args[i]}, or prints
     * usage and exits when the option is the last argument.
     */
    private static String optionValue(String[] args, int i) {
        if (i + 1 >= args.length) {
            System.err.println("Missing value for option " + args[i]);
            usage();
        }
        return args[i + 1];
    }

    /** Prints command-line usage and terminates the JVM with exit status 1. */
    public static void usage() {
        System.out.println(
                "Train [options] <ExamplesFilename> <ModelOutputFilename>\n" +
                "Options:" +
                "\n --max-iter <n>" +
                "\n --l2 <value> L2 (squared) regularization penalty." +
                "\n --l1 <value> L1 regularization penalty." +
                "\n --warm-start <modelfile> Initializes at weights of this model. discards base features that aren't in training set." +
                "\n --dump-feat Show extracted features, instead of training. Useful for debugging/analyzing feature extractors." +
                "\n"
        );
        System.exit(1);
    }
}