/** * Copyright (c) 2009, Regents of the University of Colorado All rights * reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. Redistributions in binary * form must reproduce the above copyright notice, this list of conditions and * the following disclaimer in the documentation and/or other materials provided * with the distribution. Neither the name of the University of Colorado at * Boulder nor the names of its contributors may be used to endorse or promote * products derived from this software without specific prior written * permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ package clear.parse; import clear.decode.AbstractMultiDecoder; import clear.decode.OneVsAllDecoder; import clear.dep.DepLib; import clear.dep.DepNode; import clear.dep.DepTree; import clear.ftr.map.DepFtrMap; import clear.ftr.xml.DepFtrXml; import clear.util.tuple.JIntDoubleTuple; import clear.util.tuple.JObjectDoubleTuple; import com.carrotsearch.hppc.IntArrayList; import java.util.ArrayList; /** * Shift-eager dependency parser. * * @author Jinho D. Choi <b>Last update:</b> 4/12/2011 */ public class ShiftEagerParser extends AbstractDepParser { /** * Label of Shift transition */ static public final String LB_SHIFT = "SH"; /** * Label of No-Arc transition */ static public final String LB_NO_ARC = "NA"; /** * Label of Left-Arc transition */ static public final String LB_LEFT_ARC = "LA"; /** * Label of Right-Arc transition */ static public final String LB_RIGHT_ARC = "RA"; /** * Delimiter between transition and dependency label */ static public final String LB_DELIM = "-"; /** * {@link AbstractDepParser#FLAG_TRAIN_BOOST} only. */ protected DepTree d_copy = null; /** * {@link ShiftEagerParser#FLAG_PRINT_TRANSITION} or {@link ShiftEagerParser#FLAG_TRAIN_LEXICON}. */ public ShiftEagerParser(byte flag, String filename) { super(flag, filename); } /** * {@link ShiftEagerParser#FLAG_TRAIN_INSTANCE}. */ public ShiftEagerParser(byte flag, DepFtrXml xml, String lexiconFile) { super(flag, xml, lexiconFile); } /** * {@link ShiftEagerParser#FLAG_PREDICT} or {@link ShiftEagerParser#FLAG_TRAIN_BOOST}. */ public ShiftEagerParser(byte flag, DepFtrXml xml, DepFtrMap map, AbstractMultiDecoder decoder) { super(flag, xml, map, decoder); } /** * Initializes all pointers. */ protected void init(DepTree tree) { preProcess(tree); d_tree = tree; i_lambda = 0; i_beta = 1; prev_trans = new ArrayList<>(); if (i_flag == FLAG_PRINT_TRANSITION) { printTransition("", ""); } else if (i_flag == FLAG_TRAIN_BOOST) { d_copy = tree.clone(); } } /** * Parses the dependency tree. */ @Override public void parse(DepTree tree) { init(tree); int size = tree.size(); while (i_beta < size) // beta is not empty { if (i_lambda == -1) // lambda_1 is empty: deterministic shift { shift(true); continue; } else if (i_flag == FLAG_PREDICT) { predict(); } else if (i_flag == FLAG_TRAIN_BOOST) { trainBoost(); } else { train(); } d_tree.n_trans++; } if (i_flag == FLAG_PRINT_TRANSITION) { f_out.println(); } else if (i_flag == FLAG_PREDICT) { postProcess(LB_LEFT_ARC, LB_RIGHT_ARC); } else if (i_flag == FLAG_TRAIN_BOOST) { postProcessBoost(); } } /** * Trains a dependency tree. */ private void train() { DepNode lambda = d_tree.get(i_lambda); DepNode beta = d_tree.get(i_beta); if (lambda.headId == beta.id) { leftArc(lambda, beta, lambda.deprel, 1d); } else if (lambda.id == beta.headId) { rightArc(lambda, beta, beta.deprel, 1d); } else if (isShift(d_tree)) { shift(false); } else { noArc(); } } /** * This method is called from {@link ShiftEagerParser#train()}. * * @return true if non-deterministic shift needs to be performed */ protected boolean isShift(DepTree tree) { DepNode beta = tree.get(i_beta); for (int i = i_lambda; i >= 0; i--) { DepNode lambda = tree.get(i); if (lambda.headId == beta.id || lambda.id == beta.headId) { return false; } } return true; } /** * Predicts dependencies. */ private void predict() { predictAux(getFeatureArray()); } private void trainBoost() { String gLabel = getGoldLabel(d_copy); IntArrayList ftr = getFeatureArray(); saveInstance(gLabel, ftr); predictAux(ftr); } private String predictAux(IntArrayList ftr) { JIntDoubleTuple res; res = c_dec.predict(ftr); String label = (res.i < 0) ? LB_NO_ARC : t_map.indexToLabel(res.i); int index = label.indexOf(LB_DELIM); String trans = (index > 0) ? label.substring(0, index) : label; String deprel = (index > 0) ? label.substring(index + 1) : ""; DepNode lambda = d_tree.get(i_lambda); DepNode beta = d_tree.get(i_beta); if (trans.equals(LB_LEFT_ARC) && !d_tree.isAncestor(lambda, beta) && lambda.id != DepLib.ROOT_ID) { leftArc(lambda, beta, deprel, res.d); } else if (trans.equals(LB_RIGHT_ARC) && !d_tree.isAncestor(beta, lambda)) { rightArc(lambda, beta, deprel, res.d); } else if (trans.equals(LB_SHIFT)) { shift(false); } else { noArc(); } return label; } private String getGoldLabel(DepTree tree) { DepNode lambda = tree.get(i_lambda); DepNode beta = tree.get(i_beta); if (lambda.headId == beta.id) { return LB_LEFT_ARC + LB_DELIM + lambda.deprel; } else if (lambda.id == beta.headId) { return LB_RIGHT_ARC + LB_DELIM + beta.deprel; } else if (isShift(tree)) { return LB_SHIFT; } else { return LB_NO_ARC; } } /** * Predicts dependencies for tokens that have not found their heads during * parsing. */ protected void postProcess(String leftLabels, String rightLabels) { int currId, maxId, i, n = d_tree.size(); JObjectDoubleTuple<String> max; DepNode curr, node; for (currId = 1; currId < n; currId++) { curr = d_tree.get(currId); if (curr.hasHead) { continue; } max = new JObjectDoubleTuple<>(null, -1000); maxId = -1; for (i = currId - 1; i >= 0; i--) { node = d_tree.get(i); if (d_tree.isAncestor(curr, node)) { continue; } maxId = getMaxHeadId(curr, node, maxId, max, rightLabels); } for (i = currId + 1; i < d_tree.size(); i++) { node = d_tree.get(i); if (d_tree.isAncestor(curr, node)) { continue; } maxId = getMaxHeadId(curr, node, maxId, max, leftLabels); } if (maxId != -1) { curr.setHead(maxId, max.object, max.value); } } } /** * This method is called from {@link ShiftEagerParser#postProcess()}. */ protected int getMaxHeadId(DepNode curr, DepNode head, int maxId, JObjectDoubleTuple<String> max, String sTrans) { if (curr.id < head.id) { i_lambda = curr.id; i_beta = head.id; } else { i_lambda = head.id; i_beta = curr.id; } JIntDoubleTuple[] aRes; JIntDoubleTuple res; String label, trans; int index; aRes = ((OneVsAllDecoder) c_dec).predictAll(getFeatureArray()); if (curr.id < head.id && t_map.indexToLabel(aRes[0].i).equals(LB_SHIFT)) { return maxId; } for (int i = 0; i < aRes.length; i++) { res = aRes[i]; label = t_map.indexToLabel(res.i); index = label.indexOf(LB_DELIM); if (index == -1) { continue; } trans = label.substring(0, index); if (trans.matches(sTrans)) { if (max.value < res.d) { String deprel = label.substring(index + 1); max.set(deprel, res.d); maxId = head.id; } break; } } return maxId; } private void postProcessBoost() { int currId, n = d_tree.size(); DepNode curr; for (currId = 1; currId < n; currId++) { if (d_tree.get(currId).hasHead) { continue; } curr = d_copy.get(currId); i_lambda = currId - 1; i_beta = currId; if (isShift(d_copy)) { saveInstance(LB_SHIFT, getFeatureArray()); } if (currId < curr.headId) { i_lambda = currId; i_beta = curr.headId; } else { i_lambda = curr.headId; i_beta = currId; } saveInstance(getGoldLabel(d_copy), getFeatureArray()); } } /** * Performs a shift transition. * * @param isDeterministic true if this is called for a deterministic-shift. */ protected void shift(boolean isDeterministic) { if (!isDeterministic) { trainInstance(LB_SHIFT); } i_lambda = i_beta++; prev_trans.clear(); if (i_flag == FLAG_PRINT_TRANSITION) { if (isDeterministic) { printTransition("DT-SHIFT", ""); } else { printTransition("NT-SHIFT", ""); } } } /** * Performs a no-arc transition. */ protected void noArc() { trainInstance(LB_NO_ARC); i_lambda--; prev_trans.add(LB_NO_ARC); if (i_flag == FLAG_PRINT_TRANSITION) { printTransition("NO-ARC", ""); } } /** * Performs a left-arc transition. * * @param lambda lambda_1[0] * @param beta beta[0] * @param deprel dependency label between * <code>lambda</code> and * <code>beta</code> * @param score dependency score between * <code>lambda</code> and * <code>beta</code> */ protected void leftArc(DepNode lambda, DepNode beta, String deprel, double score) { String label = LB_LEFT_ARC + LB_DELIM + deprel; trainInstance(label); lambda.setHead(beta.id, deprel, score); if (beta.leftMostDep == null || lambda.id < beta.leftMostDep.id) { beta.leftMostDep = lambda; } i_lambda--; prev_trans.add(label); if (i_flag == FLAG_PRINT_TRANSITION) { printTransition("LEFT-ARC", lambda.id + " <-" + deprel + "- " + beta.id); } } /** * Performs a right-arc transition. * * @param lambda lambda_1[0] * @param beta beta[0] * @param deprel dependency label between lambda_1[0] and beta[0] * @param score dependency score between lambda_1[0] and beta[0] */ protected void rightArc(DepNode lambda, DepNode beta, String deprel, double score) { String label = LB_RIGHT_ARC + LB_DELIM + deprel; trainInstance(label); beta.setHead(lambda.id, deprel, score); if (lambda.rightMostDep == null || lambda.rightMostDep.id < beta.id) { lambda.rightMostDep = beta; } i_lambda--; prev_trans.add(label); if (i_flag == FLAG_PRINT_TRANSITION) { printTransition("RIGHT-ARC", lambda.id + " -" + deprel + "-> " + beta.id); } } }