package edu.stanford.nlp.parser.shiftreduce; import java.util.List; import java.util.ListIterator; import edu.stanford.nlp.util.Generics; /** * A second attempt at making an oracle. Instead of always trying to * return the best transition, it simply rearranges the transition * lists after an incorrect transition. If this is not possible, * training will be halted as in the case of early update. * * @author John Bauer */ public class ReorderingOracle { ShiftReduceOptions op; public ReorderingOracle(ShiftReduceOptions op) { this.op = op; } /** * Given a predicted transition and a state, this method rearranges * the list of transitions and returns whether or not training can * continue. */ boolean reorder(State state, Transition chosenTransition, List<Transition> transitions) { if (transitions.size() == 0) { throw new AssertionError(); } Transition goldTransition = transitions.get(0); // If the transition is gold, we are already satisfied. if (chosenTransition.equals(goldTransition)) { transitions.remove(0); return true; } // If the transition should have been a Unary/CompoundUnary // transition and it was something else or a different Unary // transition, see if the transition sequence can be continued // after skipping past the unary if ((goldTransition instanceof UnaryTransition) || (goldTransition instanceof CompoundUnaryTransition)) { transitions.remove(0); return reorder(state, chosenTransition, transitions); } // If the chosen transition was an incorrect Unary/CompoundUnary // transition, skip past it and hope to continue the gold // transition sequence. However, if we have Unary/CompoundUnary // in a row, we have to return false to prevent loops. // Also, if the state stack size is 0, can't keep going if ((chosenTransition instanceof UnaryTransition) || (chosenTransition instanceof CompoundUnaryTransition)) { if (state.transitions.size() > 0) { Transition previous = state.transitions.peek(); if ((previous instanceof UnaryTransition) || (previous instanceof CompoundUnaryTransition)) { return false; } } if (state.stack.size() == 0) { return false; } return true; } if (chosenTransition instanceof BinaryTransition) { if (state.stack.size() < 2) { return false; } if (goldTransition instanceof ShiftTransition) { // Helps, but adds quite a bit of size to the model and only helps a tiny bit return op.trainOptions().oracleBinaryToShift && reorderIncorrectBinaryTransition(transitions); } if (!(goldTransition instanceof BinaryTransition)) { return false; } BinaryTransition chosenBinary = (BinaryTransition) chosenTransition; BinaryTransition goldBinary = (BinaryTransition) goldTransition; if (chosenBinary.isBinarized()) { // Binarized labels only work (for now, at least) if the side // is wrong but the label itself is correct if (goldBinary.isBinarized() && chosenBinary.label.equals(goldBinary.label)) { transitions.remove(0); return true; } else { return false; } } // In all other binarized situations, essentially what has // happened is we added a bracket error, but future brackets can // still wind up being correct transitions.remove(0); return true; } if ((chosenTransition instanceof ShiftTransition) && (goldTransition instanceof BinaryTransition)) { // can't shift at the end of the queue if (state.endOfQueue()) { return false; } // doesn't help, sadly BinaryTransition goldBinary = (BinaryTransition) goldTransition; if (!goldBinary.isBinarized()) { return op.trainOptions().oracleShiftToBinary && reorderIncorrectShiftTransition(transitions); } } return false; } static boolean reorderIncorrectBinaryTransition(List<Transition> transitions) { int shiftCount = 0; ListIterator<Transition> cursor = transitions.listIterator(); do { if (!cursor.hasNext()) { return false; } Transition next = cursor.next(); if (next instanceof ShiftTransition) { ++shiftCount; } else if (next instanceof BinaryTransition) { --shiftCount; if (shiftCount <= 0) { cursor.remove(); } } } while (shiftCount > 0); if (!cursor.hasNext()) { return false; } Transition next = cursor.next(); while ((next instanceof UnaryTransition) || (next instanceof CompoundUnaryTransition)) { cursor.remove(); if (!cursor.hasNext()) { return false; } next = cursor.next(); } // At this point, the rest of the transition sequence should suffice return true; } /** * In this case, we are starting to build a new subtree when instead * we should have been combining existing trees. What we can do is * find the transitions that build up the next subtree in the gold * transition list, figure out how it gets applied to a * BinaryTransition, and make that the next BinaryTransition we * perform after finishing the subtree. If there are multiple * BinaryTransitions in a row, we ignore any associated * UnaryTransitions (unfixable) and try to transition to the final * state. The assumption is that we can't do anything about the * incorrect subtrees any more, so we skip them all. *<br> * Sadly, this does not seem to help - the parser gets worse when it * learns these states */ static boolean reorderIncorrectShiftTransition(List<Transition> transitions) { List<BinaryTransition> leftoverBinary = Generics.newArrayList(); while (transitions.size() > 0) { Transition head = transitions.remove(0); if (head instanceof ShiftTransition) { break; } if (head instanceof BinaryTransition) { leftoverBinary.add((BinaryTransition) head); } } if (transitions.size() == 0 || leftoverBinary.size() == 0) { // honestly this is an error we should probably just throw return false; } int shiftCount = 0; ListIterator<Transition> cursor = transitions.listIterator(); BinaryTransition lastBinary = null; while (cursor.hasNext() && shiftCount >= 0) { Transition next = cursor.next(); if (next instanceof ShiftTransition) { ++shiftCount; } else if (next instanceof BinaryTransition) { --shiftCount; if (shiftCount < 0) { lastBinary = (BinaryTransition) next; cursor.remove(); } } } if (!cursor.hasNext() || lastBinary == null) { // once again, an error. even if the sequence of tree altering // gold transitions ends with a BinaryTransition, there should // be a FinalizeTransition after that return false; } String label = lastBinary.label; if (lastBinary.isBinarized()) { label = label.substring(1); } if (lastBinary.side == BinaryTransition.Side.RIGHT) { // When we finally transition all the binary transitions, we // will want to have the new node be the right head. Therefore, // we add a bunch of temporary binary transitions with a right // head, ending up with a binary transition with a right head for (int i = 0; i < leftoverBinary.size(); ++i) { cursor.add(new BinaryTransition("@" + label, BinaryTransition.Side.RIGHT)); } // use lastBinary.label in case the last transition is temporary cursor.add(new BinaryTransition(lastBinary.label, BinaryTransition.Side.RIGHT)); } else { cursor.add(new BinaryTransition("@" + label, BinaryTransition.Side.LEFT)); for (int i = 0; i < leftoverBinary.size() - 1; ++i) { cursor.add(new BinaryTransition("@" + label, leftoverBinary.get(i).side)); } cursor.add(new BinaryTransition(lastBinary.label, leftoverBinary.get(leftoverBinary.size() - 1).side)); } return true; } }