///////////////////////////////////////////////////////////////////////////////
// Copyright (C) 2008 Carnegie Mellon University and
// (C) 2007 University of Texas at Austin and (C) 2005
// University of Pennsylvania and Copyright (C) 2002, 2003 University
// of Massachusetts Amherst, Department of Computer Science.
//
// This software is licensed under the terms of the Common Public
// License, Version 1.0 or (at your option) any subsequent version.
//
// The license is approved by the Open Source Initiative, and is
// available from their website at http://www.opensource.org.
///////////////////////////////////////////////////////////////////////////////

package mstparser.mallet;

import edu.umass.cs.mallet.base.classify.Classification;
import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.MaxEntTrainer;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.SerialPipes;
import edu.umass.cs.mallet.base.pipe.Token2FeatureVector;
// The single-type imports below resolve name clashes: both
// edu.umass.cs.mallet.base.types and mstparser define Alphabet and
// FeatureVector, so the ambiguous names are pinned explicitly.
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.*;

import gnu.trove.TObjectIntHashMap;

import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.io.PrintStream;
import java.util.Date;

import mstparser.FeatureVector;
import mstparser.*;

/**
 * Trains and applies a Mallet maximum-entropy classifier that assigns
 * dependency labels as a separate stage, using the features produced by
 * {@link MalletFeatures}.
 *
 * @author Dipanjan Das 6/4/08 dipanjan@cs.cmu.edu
 *
 * Adapted from code by Ryan McDonald (ryanmcd@google.com)
 */
public class LabelClassifier {

    private ParserOptions options;
    private int[] instanceLengths;
    private int[] ignore;
    private String trainFile;
    private File trainForest;
    private DependencyParser parentParser;
    private DependencyPipe pipe;
    private Parameters params;

    /** Counts of labeling-feature predicates; used for the frequency cutoff. */
    public TObjectIntHashMap predCounts = null;
    private int cutOff = 0;

    /*
     * Constructor for training.
     */
    public LabelClassifier(ParserOptions options, int[] instanceLengths,
                           int[] ignore, String trainFile, File trainForest,
                           DependencyParser parser, DependencyPipe pipe) {
        this.options = options;
        this.instanceLengths = instanceLengths;
        this.ignore = ignore;
        this.trainFile = trainFile;
        this.trainForest = trainForest;
        this.parentParser = parser;
        this.pipe = pipe;
        params = new Parameters(pipe.dataAlphabet.size());
        cutOff = options.separateLabCutOff;
    }

    /*
     * Constructor for test.
     */
    public LabelClassifier(ParserOptions options) {
        cutOff = options.separateLabCutOff;
    }

    /*
     * Relabels each dependency of a parsed sentence with the classifier's
     * best label. Note: lab and par do not include the root word, so both
     * are shifted right by one position internally.
     */
    public String[] outputLabels(Classifier testClassifier, String[] toks,
                                 String[] pos, String[] lab, int[] par,
                                 String[] depPred, int[] headPred,
                                 DependencyInstance instance) {
        Pipe p = (SerialPipes) testClassifier.getInstancePipe();

        String[] newLab = new String[lab.length + 1];
        newLab[0] = newLab[1]; // NOTE: newLab[1] is still null here, so the root label stays null
        int[] newPar = new int[lab.length + 1];
        newPar[0] = -1;
        for (int i = 1; i < newLab.length; i++) {
            newLab[i] = lab[i - 1];
            newPar[i] = par[i - 1];
        }

        for (int i = 1; i < newPar.length; i++) {
            String line = newLab[i] + " "
                    + MalletFeatures.getFeats(toks, pos, newLab, newPar,
                                              depPred, headPred, i).trim();
            Token t = new Token("");
            String[] tokens = line.split(" ");
            for (int j = 1; j < tokens.length; j++) {
                addFeature(t, tokens[j], 1.0);
            }

            // Redirect stderr so Mallet's chatter goes to the parser log.
            PrintStream ps = DependencyParser.out;
            PrintStream err = System.err;
            System.setErr(ps);
            LabelSequence target =
                    new LabelSequence((LabelAlphabet) p.getTargetAlphabet(), 1);
            target.add(tokens[0]);
            Instance ins = new Instance(t, target.getLabelAtPosition(0), null, null, p);
            Classification c = testClassifier.classify(ins);
            lab[i - 1] = c.getLabeling().getBestLabel().toString();
            System.setErr(err);
        }
        return lab;
    }
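    /**
     * First pass over the serialized training forest: counts how often each
     * labeling-feature predicate occurs, then prunes every predicate whose
     * count is at or below {@code separateLabCutOff}. Only surviving
     * predicates are later added to Mallet tokens by {@link #addFeature}.
     * (Illustration with hypothetical counts and cutoff 1: seeing
     * {"POS=NN": 5, "FORM=zebra": 1} keeps only "POS=NN".)
     */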
    public void passThroughInstancesToSelectFeatureThresholds() throws Exception {
        ObjectInputStream in = new ObjectInputStream(new FileInputStream(trainForest));
        predCounts = new TObjectIntHashMap();
        int numInstances = instanceLengths.length;
        int j = 0;
        DependencyParser.out.println("Passing through instances to select feature thresholds...");
        DependencyParser.out.print(" [");
        long startTime = (new Date()).getTime();
        for (int i = 0; i < numInstances; i++) {
            if ((i + 1) % 500 == 0) {
                DependencyParser.out.print((i + 1) + ",");
            }
            // afm 03-06-08
            int length = instanceLengths[i];

            // Allocate the per-sentence arrays of production feature vectors
            // and scores that readInstance expects to fill.
            FeatureVector[][][] fvs = new FeatureVector[length][length][2];
            double[][][] probs = new double[length][length][2];
            FeatureVector[][][][] nt_fvs = new FeatureVector[length][pipe.types.length][2][2];
            double[][][][] nt_probs = new double[length][pipe.types.length][2][2];
            FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length];
            double[][][] probs_trips = new double[length][length][length];
            FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2];
            double[][][] probs_sibs = new double[length][length][2];

            DependencyInstance inst;
            if (options.secondOrder) {
                inst = ((DependencyPipe2O) pipe).readInstance(in, length, fvs, probs,
                        fvs_trips, probs_trips, fvs_sibs, probs_sibs,
                        nt_fvs, nt_probs, params);
            } else {
                inst = pipe.readInstance(in, length, fvs, probs, nt_fvs, nt_probs, params);
            }
            if (ignore[i] != 0) {
                continue;
            }
            updatePredCounts(inst);
            j++;
        }
        in.close();

        // Prune predicates at or below the cutoff.
        Object[] keys = predCounts.keys();
        for (int i = 0; i < keys.length; i++) {
            if (predCounts.get(keys[i]) <= cutOff) {
                predCounts.remove(keys[i]);
            }
        }
        long endTime = (new Date()).getTime();
        DependencyParser.out.println(" |Time:" + (endTime - startTime) + "]");
        DependencyParser.out.println("Number of active features: " + predCounts.size());
    }

    public void updatePredCounts(DependencyInstance inst) {
        String[] toks = inst.forms;
        String[] pos = inst.postags;
        String[] lab = inst.deprels;
        int[] par = inst.heads;
        String[] dep_pred = null;
        int[] heads_pred = null;
        if (options.stackedLevel1) {
            dep_pred = inst.deprels_pred;
            heads_pred = inst.heads_pred;
        }
        for (int i = 1; i < par.length; i++) {
            String line = lab[i] + " "
                    + MalletFeatures.getFeats(toks, pos, lab, par, dep_pred, heads_pred, i).trim();
            String[] tokens = line.split(" ");
            // tokens[0] is the gold label; the rest are feature predicates.
            for (int j = 1; j < tokens.length; j++) {
                if (predCounts.contains(tokens[j])) {
                    predCounts.increment(tokens[j]);
                } else {
                    predCounts.put(tokens[j], 1);
                }
            }
        }
    }

    public void addFeatureVectorsForInstance(DependencyInstance inst,
                                             InstanceList trainData, Pipe p) {
        String[] toks = inst.forms;
        String[] pos = inst.postags;
        String[] lab = inst.deprels;
        int[] par = inst.heads;
        String[] dep_pred = null;
        int[] heads_pred = null;
        if (options.stackedLevel1) {
            dep_pred = inst.deprels_pred;
            heads_pred = inst.heads_pred;
        }
        for (int i = 1; i < par.length; i++) {
            String line = lab[i] + " "
                    + MalletFeatures.getFeats(toks, pos, lab, par, dep_pred, heads_pred, i).trim();
            Token t = new Token("");
            String[] tokens = line.split(" ");
            for (int j = 1; j < tokens.length; j++) {
                addFeature(t, tokens[j], 1.0);
            }
            LabelSequence target =
                    new LabelSequence((LabelAlphabet) p.getTargetAlphabet(), 1);
            target.add(tokens[0]);
            Instance ins = new Instance(t, target.getLabelAtPosition(0), null, null, p);
            trainData.add(ins);
        }
    }

    /** Adds a feature only if it survived the cutoff pass (or none was run). */
    public void addFeature(Token t, String f, double v) {
        if (predCounts == null || predCounts.size() == 0 || predCounts.contains(f)) {
            t.setFeatureValue(f, v);
        }
    }
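    /**
     * Second pass: rebuilds Mallet training instances from the forest (after
     * the cutoff pass above) and fits a maximum-entropy classifier for
     * {@code numIters} iterations. The constructor argument 1.0 is Mallet's
     * Gaussian prior variance.
     */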
    public Classifier trainClassifier(int numIters) throws Exception {
        Pipe p = new SerialPipes(new Pipe[]{ new Token2FeatureVector(true, true) });
        p.setDataAlphabet(new Alphabet());
        p.setTargetAlphabet(new LabelAlphabet());
        InstanceList trainData = new InstanceList(p);

        passThroughInstancesToSelectFeatureThresholds();

        ObjectInputStream in = new ObjectInputStream(new FileInputStream(trainForest));
        int numInstances = instanceLengths.length;
        int j = 0;
        DependencyParser.out.println("Storing feature vectors in Mallet data structure...");
        DependencyParser.out.print(" [");
        long startTime = (new Date()).getTime();
        for (int i = 0; i < numInstances; i++) {
            if ((i + 1) % 500 == 0) {
                DependencyParser.out.print((i + 1) + ",");
            }
            int length = instanceLengths[i];

            // Allocate the per-sentence arrays of production feature vectors
            // and scores that readInstance expects to fill.
            FeatureVector[][][] fvs = new FeatureVector[length][length][2];
            double[][][] probs = new double[length][length][2];
            FeatureVector[][][][] nt_fvs = new FeatureVector[length][pipe.types.length][2][2];
            double[][][][] nt_probs = new double[length][pipe.types.length][2][2];
            FeatureVector[][][] fvs_trips = new FeatureVector[length][length][length];
            double[][][] probs_trips = new double[length][length][length];
            FeatureVector[][][] fvs_sibs = new FeatureVector[length][length][2];
            double[][][] probs_sibs = new double[length][length][2];

            DependencyInstance inst;
            if (options.secondOrder) {
                inst = ((DependencyPipe2O) pipe).readInstance(in, length, fvs, probs,
                        fvs_trips, probs_trips, fvs_sibs, probs_sibs,
                        nt_fvs, nt_probs, params);
            } else {
                inst = pipe.readInstance(in, length, fvs, probs, nt_fvs, nt_probs, params);
            }
            if (ignore[i] != 0) {
                continue;
            }
            addFeatureVectorsForInstance(inst, trainData, p);
            j++;
        }
        in.close();
        predCounts = null; // counts are no longer needed once the data is built

        Alphabet dataAlph = p.getDataAlphabet();
        Alphabet tAlph = p.getTargetAlphabet();
        long endTime = (new Date()).getTime();
        DependencyParser.out.println(" |Time:" + (endTime - startTime) + "]");
        DependencyParser.out.println("Number of labels: " + tAlph.size());
        DependencyParser.out.println("Number of predicates: " + dataAlph.size());
        DependencyParser.out.println("Training size: " + trainData.size());

        MaxEntTrainer met = new MaxEntTrainer(1.0);
        met.setNumIterations(numIters);

        // Redirect stderr so Mallet's training output goes to the parser log.
        PrintStream ps = DependencyParser.out;
        PrintStream err = System.err;
        System.setErr(ps);
        Classifier classifier = met.train(trainData);
        dataAlph.stopGrowth();
        ((LabelAlphabet) p.getTargetAlphabet()).stopGrowth();
        DependencyParser.out.println("Done Training Labeler\nReturning classifier");
        System.setErr(err);
        return classifier;
    }
}
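/*
 * Usage sketch (hypothetical driver code; in MSTParser this class is invoked
 * from DependencyParser when the separate labeler is enabled):
 *
 *   // Training: count predicates, build Mallet instances, fit MaxEnt.
 *   LabelClassifier lc = new LabelClassifier(options, instanceLengths, ignore,
 *                                            trainFile, trainForest, parser, pipe);
 *   Classifier c = lc.trainClassifier(numIters);
 *
 *   // Test: relabel a parsed sentence in place.
 *   LabelClassifier tester = new LabelClassifier(options);
 *   String[] labels = tester.outputLabels(c, toks, pos, lab, par,
 *                                         depPred, headPred, instance);
 */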