// Stanford Parser -- a probabilistic lexicalized NL CFG parser
// Copyright (c) 2002 - 2014 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software Foundation,
// Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    parser-support@lists.stanford.edu
//    http://nlp.stanford.edu/software/srparser.shtml

package edu.stanford.nlp.parser.shiftreduce;

import java.io.FileFilter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserConstraint;
import edu.stanford.nlp.parser.common.ParserGrammar;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.common.ParserUtils;
import edu.stanford.nlp.parser.lexparser.BinaryHeadFinder;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.parser.metrics.ParserQueryEval;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.LabeledScoredTreeNode;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;

/**
 * Overview and description available at
 * http://nlp.stanford.edu/software/srparser.shtml
 *
 * @author John Bauer
 */
public class ShiftReduceParser extends ParserGrammar implements Serializable {
  Index<Transition> transitionIndex;
  Map<String, Weight> featureWeights;
  //final Map<String, List<ScoredObject<Integer>>> featureWeights;

  ShiftReduceOptions op;

  FeatureFactory featureFactory;

  Set<String> knownStates;

  public ShiftReduceParser(ShiftReduceOptions op) {
    this.transitionIndex = new HashIndex<Transition>();
    this.featureWeights = Generics.newHashMap();
    this.op = op;
    this.knownStates = Generics.newHashSet();

    String[] classes = op.featureFactoryClass.split(";");
    if (classes.length == 1) {
      this.featureFactory = ReflectionLoading.loadByReflection(classes[0]);
    } else {
      FeatureFactory[] factories = new FeatureFactory[classes.length];
      for (int i = 0; i < classes.length; ++i) {
        int paren = classes[i].indexOf("(");
        if (paren >= 0) {
          String arg = classes[i].substring(paren + 1, classes[i].length() - 1);
          factories[i] = ReflectionLoading.loadByReflection(classes[i].substring(0, paren), arg);
        } else {
          factories[i] = ReflectionLoading.loadByReflection(classes[i]);
        }
      }
      this.featureFactory = new CombinationFeatureFactory(factories);
    }
  }

  private ShiftReduceParser(ShiftReduceOptions op, FeatureFactory factory) {
    this.transitionIndex = new HashIndex<Transition>();
    this.featureWeights = Generics.newHashMap();
    this.op = op;
    this.featureFactory = factory;
    this.knownStates = Generics.newHashSet();
  }

  /*
  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    ObjectInputStream.GetField fields = in.readFields();
    transitionIndex = ErasureUtils.uncheckedCast(fields.get("transitionIndex", null));
    op = ErasureUtils.uncheckedCast(fields.get("op", null));
    featureFactory = ErasureUtils.uncheckedCast(fields.get("featureFactory", null));

    featureWeights = Generics.newHashMap();
    Map<String, List<ScoredObject<Integer>>> oldWeights = ErasureUtils.uncheckedCast(fields.get("featureWeights", null));
    for (String feature : oldWeights.keySet()) {
      List<ScoredObject<Integer>> oldFeature = oldWeights.get(feature);
      Weight newFeature = new Weight();
      for (int i = 0; i < oldFeature.size(); ++i) {
        newFeature.updateWeight(oldFeature.get(i).object(), (float) oldFeature.get(i).score());
      }
      featureWeights.put(feature, newFeature);
    }
  }
  */

  @Override
  public Options getOp() {
    return op;
  }

  @Override
  public TreebankLangParserParams getTLPParams() {
    return op.tlpParams;
  }

  @Override
  public TreebankLanguagePack treebankLanguagePack() {
    return getTLPParams().treebankLanguagePack();
  }

  private final static String[] BEAM_FLAGS = { "-beamSize", "4" };

  @Override
  public String[] defaultCoreNLPFlags() {
    if (op.trainOptions().beamSize > 1) {
      return ArrayUtils.concatenate(getTLPParams().defaultCoreNLPFlags(), BEAM_FLAGS);
    } else {
      // TODO: this may result in some options which are useless for
      // this model, such as -retainTmpSubcategories
      return getTLPParams().defaultCoreNLPFlags();
    }
  }

  @Override
  public boolean requiresTags() {
    return true;
  }

  public ShiftReduceParser deepCopy() {
    // TODO: should we deep copy the options / factory? seems wasteful
    ShiftReduceParser copy = new ShiftReduceParser(op, featureFactory);
    copy.copyWeights(this);
    return copy;
  }

  /**
   * Fill in the current object's weights with the other parser's weights.
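   * Also replaces the transition index and the set of known states with
   * the other parser's, so the copy behaves identically afterwards.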
   */
  public void copyWeights(ShiftReduceParser other) {
    transitionIndex.clear();
    for (Transition transition : other.transitionIndex) {
      transitionIndex.add(transition);
    }
    knownStates.clear();
    knownStates.addAll(other.knownStates);

    featureWeights.clear();
    for (String feature : other.featureWeights.keySet()) {
      featureWeights.put(feature, new Weight(other.featureWeights.get(feature)));
    }
  }

  public static ShiftReduceParser averageScoredModels(Collection<ScoredObject<ShiftReduceParser>> scoredModels) {
    if (scoredModels.size() == 0) {
      throw new IllegalArgumentException("Cannot average empty models");
    }

    System.err.print("Averaging models with scores");
    for (ScoredObject<ShiftReduceParser> model : scoredModels) {
      System.err.print(" " + NF.format(model.score()));
    }
    System.err.println();

    List<ShiftReduceParser> models = CollectionUtils.transformAsList(scoredModels, new Function<ScoredObject<ShiftReduceParser>, ShiftReduceParser>() {
        public ShiftReduceParser apply(ScoredObject<ShiftReduceParser> object) {
          return object.object();
        }});

    return averageModels(models);
  }

  public static ShiftReduceParser averageModels(Collection<ShiftReduceParser> models) {
    ShiftReduceParser firstModel = models.iterator().next();
    ShiftReduceParser copy = new ShiftReduceParser(firstModel.op, firstModel.featureFactory);

    for (Transition transition : firstModel.transitionIndex) {
      copy.transitionIndex.add(transition);
    }
    for (ShiftReduceParser model : models) {
      if (!model.transitionIndex.equals(copy.transitionIndex)) {
        throw new IllegalArgumentException("Can only average models with the same transition index");
      }
    }

    Set<String> features = Generics.newHashSet();
    for (ShiftReduceParser model : models) {
      for (String feature : model.featureWeights.keySet()) {
        features.add(feature);
      }
    }

    for (String feature : features) {
      copy.featureWeights.put(feature, new Weight());
    }

    int numModels = models.size();
    for (String feature : features) {
      for (ShiftReduceParser model : models) {
        if (!model.featureWeights.containsKey(feature)) {
          continue;
        }
        copy.featureWeights.get(feature).addScaled(model.featureWeights.get(feature), 1.0f / numModels);
      }
    }

    return copy;
  }

  @Override
  public ParserQuery parserQuery() {
    return new ShiftReduceParserQuery(this);
  }

  @Override
  public Tree apply(List<? extends HasWord> sentence) {
    ShiftReduceParserQuery pq = new ShiftReduceParserQuery(this);
    if (pq.parse(sentence)) {
      return pq.getBestParse();
    }
    return ParserUtils.xTree(sentence);
  }

  /**
   * Iterate over the feature weight map.
   * For each feature, remove all transitions with score of 0.
   * Any feature with no transitions left is then removed
   */
  public void condenseFeatures() {
    Iterator<String> featureIt = featureWeights.keySet().iterator();
    while (featureIt.hasNext()) {
      String feature = featureIt.next();
      Weight weights = featureWeights.get(feature);
      weights.condense();
      if (weights.size() == 0) {
        featureIt.remove();
      }
    }
  }

  public void filterFeatures(Set<String> keep) {
    Iterator<String> featureIt = featureWeights.keySet().iterator();
    while (featureIt.hasNext()) {
      if (!keep.contains(featureIt.next())) {
        featureIt.remove();
      }
    }
  }

  /**
   * Output some random facts about the parser
   */
  public void outputStats() {
    System.err.println("Number of known features: " + featureWeights.size());

    int numWeights = 0;
    for (String feature : featureWeights.keySet()) {
      numWeights += featureWeights.get(feature).size();
    }
    System.err.println("Number of non-zero weights: " + numWeights);

    int wordLength = 0;
    for (String feature : featureWeights.keySet()) {
      wordLength += feature.length();
    }
    System.err.println("Total word length: " + wordLength);

    System.err.println("Number of transitions: " + transitionIndex.size());
  }

  /** TODO: add an eval which measures transition accuracy? */
  @Override
  public List<Eval> getExtraEvals() {
    return Collections.emptyList();
  }

  @Override
  public List<ParserQueryEval> getParserQueryEvals() {
    if (op.testOptions().recordBinarized == null && op.testOptions().recordDebinarized == null) {
      return Collections.emptyList();
    }
    List<ParserQueryEval> evals = Generics.newArrayList();
    if (op.testOptions().recordBinarized != null) {
      evals.add(new TreeRecorder(TreeRecorder.Mode.BINARIZED, op.testOptions().recordBinarized));
    }
    if (op.testOptions().recordDebinarized != null) {
      evals.add(new TreeRecorder(TreeRecorder.Mode.DEBINARIZED, op.testOptions().recordDebinarized));
    }
    return evals;
  }

  /**
   * Returns a transition which might not even be part of the model,
   * but will hopefully allow progress in an otherwise stuck parse
   *
   * TODO: perhaps we want to create an EmergencyTransition class
   * which indicates that something has gone wrong
   */
  public Transition findEmergencyTransition(State state, List<ParserConstraint> constraints) {
    if (state.stack.size() == 0) {
      return null;
    }

    // See if there is a constraint whose boundaries match the end
    // points of the top node on the stack. If so, we can apply a
    // UnaryTransition / CompoundUnaryTransition if that would solve
    // the constraint
    if (constraints != null) {
      final Tree top = state.stack.peek();
      for (ParserConstraint constraint : constraints) {
        if (ShiftReduceUtils.leftIndex(top) != constraint.start || ShiftReduceUtils.rightIndex(top) != constraint.end - 1) {
          continue;
        }
        if (ShiftReduceUtils.constraintMatchesTreeTop(top, constraint)) {
          continue;
        }
        // found an unmatched constraint that can be fixed with a unary transition
        // now we need to find a matching state for the transition
        for (String label : knownStates) {
          if (constraint.state.matcher(label).matches()) {
            return ((op.compoundUnaries) ?
                    new CompoundUnaryTransition(Collections.singletonList(label), false) :
                    new UnaryTransition(label, false));
          }
        }
      }
    }

    if (ShiftReduceUtils.isTemporary(state.stack.peek()) &&
        (state.stack.size() == 1 ||
         ShiftReduceUtils.isTemporary(state.stack.pop().peek()))) {
      return ((op.compoundUnaries) ?
              new CompoundUnaryTransition(Collections.singletonList(state.stack.peek().value().substring(1)), false) :
              new UnaryTransition(state.stack.peek().value().substring(1), false));
    }
    if (state.stack.size() == 1) {
      return null;
    }
    if (ShiftReduceUtils.isTemporary(state.stack.peek())) {
      return new BinaryTransition(state.stack.peek().value().substring(1), BinaryTransition.Side.RIGHT);
    }
    if (ShiftReduceUtils.isTemporary(state.stack.pop().peek())) {
      return new BinaryTransition(state.stack.pop().peek().value().substring(1), BinaryTransition.Side.LEFT);
    }
    return null;
  }

  /** Convenience method: returns one highest scoring transition, without any ParserConstraints */
  public ScoredObject<Integer> findHighestScoringTransition(State state, List<String> features, boolean requireLegal) {
    Collection<ScoredObject<Integer>> transitions = findHighestScoringTransitions(state, features, requireLegal, 1, null);
    if (transitions.size() == 0) {
      return null;
    }
    return transitions.iterator().next();
  }

  public Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
    float[] scores = new float[transitionIndex.size()];
    for (String feature : features) {
      Weight weight = featureWeights.get(feature);
      if (weight == null) {
        // Features not in our index are ignored
        continue;
      }
      weight.score(scores);
    }

    PriorityQueue<ScoredObject<Integer>> queue = new PriorityQueue<ScoredObject<Integer>>(numTransitions + 1, ScoredComparator.ASCENDING_COMPARATOR);
    for (int i = 0; i < scores.length; ++i) {
      if (!requireLegal || transitionIndex.get(i).isLegal(state, constraints)) {
        queue.add(new ScoredObject<Integer>(i, scores[i]));
        if (queue.size() > numTransitions) {
          queue.poll();
        }
      }
    }

    return queue;
  }

  public static State initialStateFromGoldTagTree(Tree tree) {
    return initialStateFromTaggedSentence(tree.taggedYield());
  }

  public static State initialStateFromTaggedSentence(List<? extends HasWord> words) {
    List<Tree> preterminals = Generics.newArrayList();
    for (int index = 0; index < words.size(); ++index) {
      HasWord hw = words.get(index);

      CoreLabel wordLabel = new CoreLabel();
      // Index from 1.
      // Tools downstream from the parser expect that
      wordLabel.setIndex(index + 1);
      wordLabel.setValue(hw.word());
      if (!(hw instanceof HasTag)) {
        throw new RuntimeException("Expected tagged words");
      }
      String tag = ((HasTag) hw).tag();
      if (tag == null) {
        throw new RuntimeException("Word is not tagged");
      }
      CoreLabel tagLabel = new CoreLabel();
      tagLabel.setValue(((HasTag) hw).tag());

      LabeledScoredTreeNode wordNode = new LabeledScoredTreeNode(wordLabel);
      LabeledScoredTreeNode tagNode = new LabeledScoredTreeNode(tagLabel);
      tagNode.addChild(wordNode);

      wordLabel.set(TreeCoreAnnotations.HeadWordAnnotation.class, wordNode);
      wordLabel.set(TreeCoreAnnotations.HeadTagAnnotation.class, tagNode);
      tagLabel.set(TreeCoreAnnotations.HeadWordAnnotation.class, wordNode);
      tagLabel.set(TreeCoreAnnotations.HeadTagAnnotation.class, tagNode);

      preterminals.add(tagNode);
    }
    return new State(preterminals);
  }

  public static ShiftReduceOptions buildTrainingOptions(String tlppClass, String[] args) {
    ShiftReduceOptions op = new ShiftReduceOptions();
    op.setOptions("-forceTags", "-debugOutputFrequency", "1");
    if (tlppClass != null) {
      op.tlpParams = ReflectionLoading.loadByReflection(tlppClass);
    }
    op.setOptions(args);

    if (op.trainOptions.randomSeed == 0) {
      op.trainOptions.randomSeed = (new Random()).nextLong();
      System.err.println("Random seed not set by options, using " + op.trainOptions.randomSeed);
    }
    return op;
  }

  public Treebank readTreebank(String treebankPath, FileFilter treebankFilter) {
    System.err.println("Loading trees from " + treebankPath);
    Treebank treebank = op.tlpParams.memoryTreebank();
    treebank.loadPath(treebankPath, treebankFilter);
    System.err.println("Read in " + treebank.size() + " trees from " + treebankPath);
    return treebank;
  }

  public List<Tree> readBinarizedTreebank(String treebankPath, FileFilter treebankFilter) {
    Treebank treebank = readTreebank(treebankPath, treebankFilter);
    List<Tree> binarized = binarizeTreebank(treebank, op);
    System.err.println("Converted trees to binarized format");
    return binarized;
  }

  public static List<Tree> binarizeTreebank(Treebank treebank, Options op) {
    TreeBinarizer binarizer = new TreeBinarizer(op.tlpParams.headFinder(), op.tlpParams.treebankLanguagePack(),
                                                false, false, 0, false, false, 0.0, false, true, true);
    BasicCategoryTreeTransformer basicTransformer = new BasicCategoryTreeTransformer(op.langpack());
    CompositeTreeTransformer transformer = new CompositeTreeTransformer();
    transformer.addTransformer(binarizer);
    transformer.addTransformer(basicTransformer);

    treebank = treebank.transform(transformer);

    HeadFinder binaryHeadFinder = new BinaryHeadFinder(op.tlpParams.headFinder());
    List<Tree> binarizedTrees = Generics.newArrayList();
    for (Tree tree : treebank) {
      Trees.convertToCoreLabels(tree);
      tree.percolateHeadAnnotations(binaryHeadFinder);
      // Index from 1.
      // Tools downstream expect index from 1, so for
      // uses internal to the srparser we have to renormalize the
      // indices, with the result that here we have to index from 1
      tree.indexLeaves(1, true);
      binarizedTrees.add(tree);
    }
    return binarizedTrees;
  }

  public List<List<Transition>> createTransitionSequences(List<Tree> binarizedTrees) {
    List<List<Transition>> transitionLists = Generics.newArrayList();
    for (Tree tree : binarizedTrees) {
      List<Transition> transitions = CreateTransitionSequence.createTransitionSequence(tree, op.compoundUnaries);
      transitionLists.add(transitions);
    }
    return transitionLists;
  }

  public static void findKnownStates(List<Tree> binarizedTrees, Set<String> knownStates) {
    for (Tree tree : binarizedTrees) {
      findKnownStates(tree, knownStates);
    }
  }

  public static void findKnownStates(Tree tree, Set<String> knownStates) {
    if (tree.isLeaf() || tree.isPreTerminal()) {
      return;
    }
    if (!ShiftReduceUtils.isTemporary(tree)) {
      knownStates.add(tree.value());
    }
    for (Tree child : tree.children()) {
      findKnownStates(child, knownStates);
    }
  }

  // TODO: factor out the retagging?
  public static void redoTags(Tree tree, Tagger tagger) {
    List<Word> words = tree.yieldWords();
    List<TaggedWord> tagged = tagger.apply(words);
    List<Label> tags = tree.preTerminalYield();
    if (tags.size() != tagged.size()) {
      throw new AssertionError("Tags are not the same size");
    }
    for (int i = 0; i < tags.size(); ++i) {
      tags.get(i).setValue(tagged.get(i).tag());
    }
  }

  private static class RetagProcessor implements ThreadsafeProcessor<Tree, Tree> {
    Tagger tagger;

    public RetagProcessor(Tagger tagger) {
      this.tagger = tagger;
    }

    public Tree process(Tree tree) {
      redoTags(tree, tagger);
      return tree;
    }

    public RetagProcessor newInstance() {
      // already threadsafe
      return this;
    }
  }

  public static void redoTags(List<Tree> trees, Tagger tagger, int nThreads) {
    if (nThreads == 1) {
      for (Tree tree : trees) {
        redoTags(tree, tagger);
      }
    } else {
      MulticoreWrapper<Tree, Tree> wrapper = new MulticoreWrapper<Tree, Tree>(nThreads, new RetagProcessor(tagger));
      for (Tree tree : trees) {
        wrapper.put(tree);
      }
      wrapper.join();
      // trees are changed in place
    }
  }

  private static final NumberFormat NF = new DecimalFormat("0.00");
  private static final NumberFormat FILENAME = new DecimalFormat("0000");

  private static class Update {
    final List<String> features;
    final int goldTransition;
    final int predictedTransition;
    final float delta;

    Update(List<String> features, int goldTransition, int predictedTransition, float delta) {
      this.features = features;
      this.goldTransition = goldTransition;
      this.predictedTransition = predictedTransition;
      this.delta = delta;
    }
  }

  private Pair<Integer, Integer> trainTree(int index, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
    int numCorrect = 0;
    int numWrong = 0;

    Tree tree = binarizedTrees.get(index);

    // TODO. This training method seems to be working in that it
    // trains models just like the gold and early termination methods do.
    // However, it causes the feature space to go crazy. Presumably
    // leaving out features with low weights or low frequencies would
    // significantly help with that. Otherwise, not sure how to keep
    // it under control.
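    // The three branches below correspond to the training methods:
    // ORACLE asks the oracle for a correct transition out of whatever state
    // the parser has reached, BEAM keeps an agenda of the beamSize highest
    // scoring states and updates whenever the best state did not use the
    // gold transition (stopping early if the gold state falls off the
    // agenda), and the default branch follows the gold transition sequence,
    // updating on each mistake and breaking at the first mistake when
    // EARLY_TERMINATION is used.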
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
      State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
      while (!state.isFinished()) {
        List<String> features = featureFactory.featurize(state);
        ScoredObject<Integer> prediction = findHighestScoringTransition(state, features, true);
        if (prediction == null) {
          throw new AssertionError("Did not find a legal transition");
        }
        int predictedNum = prediction.object();
        Transition predicted = transitionIndex.get(predictedNum);
        OracleTransition gold = oracle.goldTransition(index, state);
        if (gold.isCorrect(predicted)) {
          numCorrect++;
          if (gold.transition != null && !gold.transition.equals(predicted)) {
            int transitionNum = transitionIndex.indexOf(gold.transition);
            if (transitionNum < 0) {
              // TODO: do we want to add unary transitions which are
              // only possible when the parser has gone off the rails?
              continue;
            }
            updates.add(new Update(features, transitionNum, -1, 1.0f));
          }
        } else {
          numWrong++;
          int transitionNum = -1;
          if (gold.transition != null) {
            transitionNum = transitionIndex.indexOf(gold.transition);
            // TODO: this can theoretically result in a -1 gold
            // transition if the transition exists, but is a
            // CompoundUnaryTransition which only exists because the
            // parser is wrong. Do we want to add those transitions?
          }
          updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
        }
        state = predicted.apply(state);
      }
    } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
      if (op.trainOptions().beamSize <= 0) {
        throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
      }
      List<Transition> transitions = transitionLists.get(index);
      PriorityQueue<State> agenda = new PriorityQueue<State>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
      State goldState = ShiftReduceParser.initialStateFromGoldTagTree(tree);
      agenda.add(goldState);
      int transitionCount = 0;
      for (Transition goldTransition : transitions) {
        PriorityQueue<State> newAgenda = new PriorityQueue<State>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
        State highestScoringState = null;
        State highestCurrentState = null;
        for (State currentState : agenda) {
          List<String> features = featureFactory.featurize(currentState);
          Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
          for (ScoredObject<Integer> transition : stateTransitions) {
            State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
            newAgenda.add(newState);
            if (newAgenda.size() > op.trainOptions().beamSize) {
              newAgenda.poll();
            }
            if (highestScoringState == null || highestScoringState.score() < newState.score()) {
              highestScoringState = newState;
              highestCurrentState = currentState;
            }
          }
        }
        List<String> goldFeatures = featureFactory.featurize(goldState);
        goldState = goldTransition.apply(goldState, 0.0);
        // if highest scoring state used the correct transition, no training
        // otherwise, down the last transition, up the correct
        if (!goldState.areTransitionsEqual(highestScoringState)) {
          ++numWrong;
          int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
          updates.add(new Update(featureFactory.featurize(highestCurrentState), -1, lastTransition, 1.0f));
          updates.add(new Update(goldFeatures, transitionIndex.indexOf(goldTransition), -1, 1.0f));
        } else {
          ++numCorrect;
        }
        // If the correct state has fallen off the agenda, break
        boolean found = false;
        for (State otherState : newAgenda) {
          if (otherState.areTransitionsEqual(goldState)) {
            found = true;
            break;
          }
        }
        if (!found) {
          break;
        }
        agenda = newAgenda;
      }
    } else {
      State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
      List<Transition> transitions = transitionLists.get(index);
      for (Transition transition : transitions) {
        int transitionNum = transitionIndex.indexOf(transition);
        List<String> features = featureFactory.featurize(state);
        int predictedNum = findHighestScoringTransition(state, features, false).object();
        Transition predicted = transitionIndex.get(predictedNum);
        if (transitionNum == predictedNum) {
          numCorrect++;
        } else {
          numWrong++;
          // TODO: allow weighted features, weighted training, etc
          updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
        }
        if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION && transitionNum != predictedNum) {
          break;
        }
        state = transition.apply(state);
      }
    }

    return Pair.makePair(numCorrect, numWrong);
  }

  private class TrainTreeProcessor implements ThreadsafeProcessor<Integer, Pair<Integer, Integer>> {
    List<Tree> binarizedTrees;
    List<List<Transition>> transitionLists;
    List<Update> updates; // this needs to be a synchronized list
    Oracle oracle;

    public TrainTreeProcessor(List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
      this.binarizedTrees = binarizedTrees;
      this.transitionLists = transitionLists;
      this.updates = updates;
      this.oracle = oracle;
    }

    public Pair<Integer, Integer> process(Integer index) {
      return trainTree(index, binarizedTrees, transitionLists, updates, oracle);
    }

    public TrainTreeProcessor newInstance() {
      // already threadsafe
      return this;
    }
  }

  private Triple<List<Update>, Integer, Integer> trainBatch(List<Integer> indices, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle, MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper) {
    int numCorrect = 0;
    int numWrong = 0;
    if (op.trainOptions.trainingThreads == 1) {
      for (Integer index : indices) {
        Pair<Integer, Integer> count = trainTree(index, binarizedTrees, transitionLists, updates, oracle);
        numCorrect += count.first;
        numWrong += count.second;
      }
    } else {
      for (Integer index : indices) {
        wrapper.put(index);
      }
      wrapper.join(false);
      while (wrapper.peek()) {
        Pair<Integer, Integer> result = wrapper.poll();
        numCorrect += result.first;
        numWrong += result.second;
      }
    }
    return new Triple<List<Update>, Integer, Integer>(updates, numCorrect, numWrong);
  }

  private void trainAndSave(List<Pair<String, FileFilter>> trainTreebankPath, Pair<String, FileFilter> devTreebankPath, String serializedPath) {
    List<Tree> binarizedTrees = Generics.newArrayList();
    for (Pair<String, FileFilter> treebank : trainTreebankPath) {
      binarizedTrees.addAll(readBinarizedTreebank(treebank.first(), treebank.second()));
    }

    int nThreads = op.trainOptions.trainingThreads;
    nThreads = nThreads <= 0 ? Runtime.getRuntime().availableProcessors() : nThreads;

    Tagger tagger = null;
    if (op.testOptions.preTag) {
      Timing retagTimer = new Timing();
      tagger = Tagger.loadModel(op.testOptions.taggerSerializedFile);
      redoTags(binarizedTrees, tagger, nThreads);
      retagTimer.done("Retagging");
    }

    Timing transitionTimer = new Timing();
    List<List<Transition>> transitionLists = createTransitionSequences(binarizedTrees);
    for (List<Transition> transitions : transitionLists) {
      // TODO: there is a potential bug here.
      // So far, the assumption
      // is that all unary transitions which occur at the root only
      // ever occur at the root. If that assumption doesn't hold for
      // some treebank, it may occur that a root transition occurs in
      // the middle of the tree but is marked "isRoot", meaning it can
      // never actually be used in the middle of the tree.
      //
      // A solution to this would be to keep a separate index of all
      // the transitions which have only ever been seen in the context
      // of the root. Eg, nothing comes after those transitions
      // except Finalize or Idle. (That also picks up the unlikely
      // case of a binary transition being a root transition.)
      transitionIndex.addAll(transitions);
    }
    transitionTimer.done("Converting trees into transition lists");
    System.err.println("Number of transitions: " + transitionIndex.size());

    findKnownStates(binarizedTrees, knownStates);
    System.err.println("Known states: " + knownStates);

    Random random = new Random(op.trainOptions.randomSeed);

    Treebank devTreebank = null;
    if (devTreebankPath != null) {
      devTreebank = readTreebank(devTreebankPath.first(), devTreebankPath.second());
    }

    double bestScore = 0.0;
    int bestIteration = 0;
    PriorityQueue<ScoredObject<ShiftReduceParser>> bestModels = null;
    if (op.trainOptions().averagedModels > 0) {
      bestModels = new PriorityQueue<ScoredObject<ShiftReduceParser>>(op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
    }

    List<Integer> indices = Generics.newArrayList();
    for (int i = 0; i < binarizedTrees.size(); ++i) {
      indices.add(i);
    }

    Oracle oracle = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
      oracle = new Oracle(binarizedTrees, op.compoundUnaries);
    }

    List<Update> updates = Generics.newArrayList();
    MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper = null;
    if (nThreads != 1) {
      updates = Collections.synchronizedList(updates);
      wrapper = new MulticoreWrapper<Integer, Pair<Integer, Integer>>(op.trainOptions.trainingThreads, new TrainTreeProcessor(binarizedTrees, transitionLists, updates, oracle));
    }

    IntCounter<String> featureFrequencies = null;
    if (op.trainOptions().featureFrequencyCutoff > 1) {
      featureFrequencies = new IntCounter<String>();
    }

    for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration) {
      Timing trainingTimer = new Timing();
      int numCorrect = 0;
      int numWrong = 0;
      Collections.shuffle(indices, random);
      for (int start = 0; start < indices.size(); start += op.trainOptions.batchSize) {
        int end = Math.min(start + op.trainOptions.batchSize, indices.size());
        Triple<List<Update>, Integer, Integer> result = trainBatch(indices.subList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
        numCorrect += result.second;
        numWrong += result.third;
        for (Update update : result.first) {
          for (String feature : update.features) {
            Weight weights = featureWeights.get(feature);
            if (weights == null) {
              weights = new Weight();
              featureWeights.put(feature, weights);
            }
            weights.updateWeight(update.goldTransition, update.delta);
            weights.updateWeight(update.predictedTransition, -update.delta);

            if (featureFrequencies != null) {
              featureFrequencies.incrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ?
                                                         2 : 1);
            }
          }
        }
        updates.clear();
      }
      trainingTimer.done("Iteration " + iteration);
      System.err.println("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
      outputStats();

      double labelF1 = 0.0;
      if (devTreebank != null) {
        EvaluateTreebank evaluator = new EvaluateTreebank(op, null, this, tagger);
        evaluator.testOnTreebank(devTreebank);
        labelF1 = evaluator.getLBScore();
        System.err.println("Label F1 after " + iteration + " iterations: " + labelF1);

        if (labelF1 > bestScore) {
          System.err.println("New best dev score (previous best " + bestScore + ")");
          bestScore = labelF1;
          bestIteration = iteration;
        } else {
          System.err.println("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
          if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit)) {
            System.err.println("Failed to improve for too long, stopping training");
            break;
          }
        }

        if (bestModels != null) {
          bestModels.add(new ScoredObject<ShiftReduceParser>(this.deepCopy(), labelF1));
          if (bestModels.size() > op.trainOptions().averagedModels) {
            bestModels.poll();
          }
        }
      }

      if (op.trainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0) {
        String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + ".ser.gz";
        saveModel(tempName);
        // TODO: we could save a cutoff version of the model,
        // especially if we also get a dev set number for it, but that
        // might be overkill
      }
    }

    if (wrapper != null) {
      wrapper.join();
    }

    if (bestModels != null) {
      if (op.trainOptions().cvAveragedModels && devTreebank != null) {
        List<ScoredObject<ShiftReduceParser>> models = Generics.newArrayList();
        while (bestModels.size() > 0) {
          models.add(bestModels.poll());
        }
        Collections.reverse(models);
        double bestF1 = 0.0;
        int bestSize = 0;
        for (int i = 1; i <= models.size(); ++i) {
          System.err.println("Testing with " + i + " models averaged together");
          ShiftReduceParser parser = averageScoredModels(models.subList(0, i));
          EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
          evaluator.testOnTreebank(devTreebank);
          double labelF1 = evaluator.getLBScore();
          System.err.println("Label F1 for " + i + " models: " + labelF1);
          if (labelF1 > bestF1) {
            bestF1 = labelF1;
            bestSize = i;
          }
        }
        copyWeights(averageScoredModels(models.subList(0, bestSize)));
      } else {
        copyWeights(ShiftReduceParser.averageScoredModels(bestModels));
      }
    }

    // TODO: perhaps we should filter the features and then get dev
    // set scores. That way we can merge the models which are best
    // after filtering.
    if (featureFrequencies != null) {
      filterFeatures(featureFrequencies.keysAbove(op.trainOptions().featureFrequencyCutoff));
    }
    condenseFeatures();

    if (serializedPath != null) {
      try {
        IOUtils.writeObjectToFile(this, serializedPath);
      } catch (IOException e) {
        throw new RuntimeIOException(e);
      }
    }
  }

  public void setOptionFlags(String ... flags) {
    op.setOptions(flags);
  }

  public static ShiftReduceParser loadModel(String path,
                                            String ... extraFlags) {
    ShiftReduceParser parser = null;
    try {
      Timing timing = new Timing();
      System.err.print("Loading parser from serialized file " + path + " ...");
      parser = IOUtils.readObjectFromURLOrClasspathOrFileSystem(path);
      timing.done();
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    } catch (ClassNotFoundException e) {
      throw new RuntimeIOException(e);
    }
    if (extraFlags.length > 0) {
      parser.setOptionFlags(extraFlags);
    }
    return parser;
  }

  public void saveModel(String path) {
    try {
      IOUtils.writeObjectToFile(this, path);
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  static final String[] FORCE_TAGS = { "-forceTags" };

  public static void main(String[] args) {
    List<String> remainingArgs = Generics.newArrayList();

    List<Pair<String, FileFilter>> trainTreebankPath = null;
    Pair<String, FileFilter> testTreebankPath = null;
    Pair<String, FileFilter> devTreebankPath = null;

    String serializedPath = null;
    String tlppClass = null;
    String continueTraining = null;

    for (int argIndex = 0; argIndex < args.length; ) {
      if (args[argIndex].equalsIgnoreCase("-trainTreebank")) {
        if (trainTreebankPath == null) {
          trainTreebankPath = Generics.newArrayList();
        }
        trainTreebankPath.add(ArgUtils.getTreebankDescription(args, argIndex, "-trainTreebank"));
        argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
        testTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
        argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      } else if (args[argIndex].equalsIgnoreCase("-devTreebank")) {
        devTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-devTreebank");
        argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
      } else if (args[argIndex].equalsIgnoreCase("-serializedPath") || args[argIndex].equalsIgnoreCase("-model")) {
        serializedPath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-tlpp")) {
        tlppClass = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
        continueTraining = args[argIndex + 1];
        argIndex += 2;
      } else {
        remainingArgs.add(args[argIndex]);
        ++argIndex;
      }
    }

    String[] newArgs = new String[remainingArgs.size()];
    newArgs = remainingArgs.toArray(newArgs);

    if (trainTreebankPath == null && serializedPath == null) {
      throw new IllegalArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
    }

    ShiftReduceParser parser = null;

    if (trainTreebankPath != null) {
      System.err.println("Training ShiftReduceParser");
      System.err.println("Initial arguments:");
      System.err.println(" " + StringUtils.join(args));
      if (continueTraining != null) {
        parser = ShiftReduceParser.loadModel(continueTraining, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
      } else {
        ShiftReduceOptions op = buildTrainingOptions(tlppClass, newArgs);
        parser = new ShiftReduceParser(op);
      }
      parser.trainAndSave(trainTreebankPath, devTreebankPath, serializedPath);
    }

    if (serializedPath != null && parser == null) {
      parser = ShiftReduceParser.loadModel(serializedPath, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
    }

    //parser.outputStats();

    if (testTreebankPath != null) {
      System.err.println("Loading test trees from " + testTreebankPath.first());
      Treebank testTreebank = parser.op.tlpParams.memoryTreebank();
      testTreebank.loadPath(testTreebankPath.first(), testTreebankPath.second());
      System.err.println("Loaded " + testTreebank.size() + " trees");

      EvaluateTreebank evaluator = new
        EvaluateTreebank(parser.op, null, parser);
      evaluator.testOnTreebank(testTreebank);

      // System.err.println("Input tree: " + tree);
      // System.err.println("Debinarized tree: " + query.getBestParse());
      // System.err.println("Parsed binarized tree: " + query.getBestBinarizedParse());
      // System.err.println("Predicted transition sequence: " + query.getBestTransitionSequence());
    }
  }

  private static final long serialVersionUID = 1;
}