/** * Copyright (c) 2009, Regents of the University of Colorado All rights * reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. Redistributions in binary * form must reproduce the above copyright notice, this list of conditions and * the following disclaimer in the documentation and/or other materials provided * with the distribution. Neither the name of the University of Colorado at * Boulder nor the names of its contributors may be used to endorse or promote * products derived from this software without specific prior written * permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. 
*/

package clear.parse;

import clear.decode.AbstractDecoder;
import clear.decode.OneVsAllDecoder;
import clear.dep.DepLib;
import clear.dep.DepNode;
import clear.dep.DepTree;
import clear.dep.srl.SRLArg;
import clear.dep.srl.SRLHead;
import clear.dep.srl.SRLInfo;
import clear.ftr.map.SRLFtrMap;
import clear.ftr.xml.SRLFtrXml;
import clear.util.tuple.JIntDoubleTuple;
import com.carrotsearch.hppc.IntArrayList;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.regex.Pattern;

/**
 * Transition-based semantic role labeler.
 *
 * <p>For each predicate (<i>beta</i>), the parser walks candidate arguments (<i>lambda</i>)
 * first to the left of the predicate and then to the right. At every step it decides
 * whether lambda is an argument of beta ({@link #LB_NO_ARC} versus an argument label).
 * The arguments collected for a predicate are attached to the tree once its rightward
 * scan finishes; this happens only in predict/boost modes.
 *
 * @author Jinho D. Choi <b>Last update:</b> 11/6/2010
 */
public class SRLParser extends AbstractSRLParser {
    /** Label of the Shift transition. */
    public static final String LB_SHIFT = "SH";

    /** Label of the NoArc transition. */
    public static final String LB_NO_ARC = "NA";

    /**
     * Numbered argument labels: "A" followed by a single digit (e.g. {@code A0}).
     * Compiled once instead of on every {@code String.matches} call.
     */
    private static final Pattern P_ARGN = Pattern.compile("A\\d");

    /**
     * Copy of the input tree, taken in {@link #init(DepTree)} before the working tree's SRL
     * heads are cleared. It is populated for {@link AbstractSRLParser#FLAG_PREDICT} and
     * {@link AbstractSRLParser#FLAG_TRAIN_BOOST}. Only {@link #trainConditional()} (boost
     * mode) reads it, as the source of gold labels.
     */
    protected DepTree d_copy = null;

    /**
     * Creates a parser for {@link AbstractSRLParser#FLAG_TRAIN_LEXICON}.
     *
     * @param flag    parser mode flag
     * @param xmlFile feature-template XML file
     */
    public SRLParser(byte flag, String xmlFile) {
        super(flag, xmlFile);
    }

    /**
     * Creates a parser for {@link AbstractSRLParser#FLAG_TRAIN_INSTANCE}.
     *
     * @param flag        parser mode flag
     * @param xml         feature templates
     * @param lexiconFile lexicon files
     */
    public SRLParser(byte flag, SRLFtrXml xml, String[] lexiconFile) {
        super(flag, xml, lexiconFile);
    }

    /**
     * Creates a parser for {@link AbstractSRLParser#FLAG_PREDICT} or
     * {@link AbstractSRLParser#FLAG_TRAIN_BOOST}.
     *
     * @param flag    parser mode flag
     * @param xml     feature templates
     * @param map     feature maps
     * @param decoder decoders
     */
    public SRLParser(byte flag, SRLFtrXml xml, SRLFtrMap[] map, AbstractDecoder[] decoder) {
        super(flag, xml, map, decoder);
    }

    /**
     * Resets the parsing state for {@code tree}.
     *
     * <p>The first predicate becomes beta, and scanning starts at the token immediately to
     * its left. In predict/boost modes, the tree is first cloned into {@link #d_copy} and
     * its SRL heads are then cleared, so that arguments are re-derived by the parser.
     */
    private void init(DepTree tree) {
        tree.setSubcat();
        d_tree = tree;
        i_beta = tree.nextPredicateId(0);
        i_lambda = i_beta - 1;
        i_dir = DIR_LEFT;
        ls_args = new ArrayList<>();
        ls_argn = new ArrayList<>();

        if (i_flag == FLAG_PREDICT || i_flag == FLAG_TRAIN_BOOST) {
            // keep the original heads in d_copy, then parse d_tree from a clean slate
            d_copy = tree.clone();
            d_tree.clearSRLHeads();
        }
    }

    /**
     * Parses {@code tree}. Depending on the mode flag, this collects lexica, saves training
     * instances, trains conditionally (boost), or predicts SRL arguments in place.
     */
    @Override
    public void parse(DepTree tree) {
        init(tree);

        while (i_beta < tree.size()) {
            if (i_lambda <= 0 || i_lambda >= tree.size()) {
                // lambda ran off the root or the end of the sentence: change direction/predicate
                shift();
            } else if (i_flag == FLAG_PREDICT) {
                predict();
            } else if (i_flag == FLAG_TRAIN_BOOST) {
                trainConditional();
            } else {
                train();
            }
        }
    }

    /** Applies the gold transition read from the working tree ({@link #d_tree}). */
    private void train() {
        String label = getGoldLabel(d_tree);

        if (label.equals(LB_NO_ARC)) {
            noArc(1d);
        } else {
            yesArc(label, 1d);
        }
    }

    /** Predicts and applies the transition for the current (lambda, beta) pair. */
    private void predict() {
        predictAux(getFeatureArray());
    }

    /**
     * Decodes {@code ftr} with the current decoder and applies the resulting transition.
     * A negative label index returned by the decoder is treated as {@link #LB_NO_ARC}.
     *
     * @param ftr feature indices for the current (lambda, beta) pair
     */
    private void predictAux(IntArrayList ftr) {
        SRLFtrMap map = getFtrMap();
        OneVsAllDecoder dec = getDecoder();
        JIntDoubleTuple res = dec.predict(ftr);
        String label = (res.i < 0) ? LB_NO_ARC : map.indexToLabel(res.i);

        if (label.equals(LB_NO_ARC)) {
            noArc(res.d);
        } else {
            yesArc(label, res.d);
        }
    }

    /**
     * Saves a training instance labeled from the gold copy ({@link #d_copy}). It then
     * advances the parser using the model's own prediction rather than the gold label.
     */
    private void trainConditional() {
        String gLabel = getGoldLabel(d_copy);
        IntArrayList ftr = getFeatureArray();

        saveInstance(gLabel, ftr);
        predictAux(ftr);
    }

    /**
     * Returns the argument label linking lambda to beta in {@code tree}.
     *
     * @param tree tree that holds the reference labels
     * @return the label, or {@link #LB_NO_ARC} if lambda has no label for beta
     */
    private String getGoldLabel(DepTree tree) {
        String label = tree.get(i_lambda).getLabel(i_beta);
        return (label != null) ? label : LB_NO_ARC;
    }

    /**
     * Performs a shift transition by reversing the scan direction.
     *
     * <p>When the rightward scan is finished, the arguments collected for the current
     * predicate are committed (see {@link #shiftRight()}) and beta advances to the next
     * predicate. Lambda then restarts adjacent to beta in the new direction.
     */
    private void shift() {
        if (i_dir == DIR_RIGHT) {
            shiftRight();
            i_beta = d_tree.nextPredicateId(i_beta);
        }

        i_dir *= -1;
        i_lambda = i_beta + i_dir;
    }

    /**
     * Called from {@link #shift()} at the end of a {@link AbstractSRLParser#DIR_RIGHT} scan.
     * In predict/boost modes, this attaches the collected arguments to the tree. In every
     * mode, it resets the per-predicate argument lists.
     */
    private void shiftRight() {
        if (i_flag == FLAG_PREDICT || i_flag == FLAG_TRAIN_BOOST) {
            addArgs(ls_args);
        }

        ls_args.clear();
        ls_argn.clear();
    }

    /**
     * Adds each argument in {@code seq} to the working tree as an SRL head pointing at beta.
     *
     * @param seq arguments collected for the current predicate
     */
    private void addArgs(ArrayList<SRLArg> seq) {
        for (SRLArg arg : seq) {
            d_tree.get(arg.argId).addSRLHead(i_beta, arg.label, arg.score);
        }
    }

    /**
     * Performs a no-arc transition: records {@link #LB_NO_ARC} as a training label (in the
     * training modes) and moves lambda one step in the current direction.
     *
     * @param score decision score; currently unused because no-arc decisions are not stored
     */
    private void noArc(double score) {
        trainInstance(LB_NO_ARC);
        i_lambda += i_dir;
    }

    /**
     * Performs an arc transition. It records {@code label} as a training label (in the
     * training modes) and remembers lambda as an argument of beta. It then moves lambda one
     * step in the current direction.
     *
     * @param label argument label
     * @param score decision score stored with the argument
     */
    private void yesArc(String label, double score) {
        trainInstance(label);

        ls_args.add(new SRLArg(i_lambda, label, score));

        // numbered argument labels are also tracked separately in ls_argn
        if (P_ARGN.matcher(label).matches()) {
            ls_argn.add(label);
        }

        i_lambda += i_dir;
    }

    /**
     * Emits training data for {@code label}: lexicon tags for
     * {@link AbstractSRLParser#FLAG_TRAIN_LEXICON}, or a feature instance for
     * {@link AbstractSRLParser#FLAG_TRAIN_INSTANCE}. In any other mode, this does nothing.
     */
    private void trainInstance(String label) {
        if (i_flag == FLAG_TRAIN_LEXICON) {
            addTags(label);
        } else if (i_flag == FLAG_TRAIN_INSTANCE) {
            saveInstance(label, getFeatureArray());
        }
    }

    // ---------------------------- getFtr*() ----------------------------

    /**
     * Collects lexica for the current state: n-gram lexica, plus extra lexicon 0 (the
     * dependency labels of beta's dependents) and extra lexicon 1 (see {@link #getPredArg()}).
     */
    @Override
    protected void addLexica(SRLFtrMap map) {
        addNgramLexica(map);
        addSetLexica(map, 0, d_tree.getDeprelDepSet(i_beta));
        addStrLexica(map, 1, getPredArg());
    }

    /**
     * Adds every string in {@code ftrs} to extra lexicon {@code ftrId}.
     */
    protected void addSetLexica(SRLFtrMap map, int ftrId, AbstractCollection<String> ftrs) {
        for (String ftr : ftrs) {
            map.addExtra(ftrId, ftr);
        }
    }

    /**
     * Adds {@code ftr} to extra lexicon {@code ftrId}, unless it is {@code null}.
     */
    protected void addStrLexica(SRLFtrMap map, int ftrId, String ftr) {
        if (ftr != null) {
            map.addExtra(ftrId, ftr);
        }
    }

    /**
     * Returns {@code lemma_label} for the last entry in lambda's SRL head list whose
     * predicate precedes beta. This is only computed while scanning leftward.
     *
     * @return the string, or {@code null} when scanning right or when no such head exists
     */
    protected String getPredArg() {
        if (i_dir == DIR_RIGHT) {
            return null;
        }

        SRLInfo info = d_tree.get(i_lambda).srlInfo;

        for (int i = info.heads.size() - 1; i >= 0; i--) {
            SRLHead head = info.heads.get(i);

            if (head.headId < i_beta) {
                DepNode pred = d_tree.get(head.headId);
                return pred.lemma + "_" + head.label;
            }
        }

        return null;
    }

    /**
     * Builds the feature vector for the current (lambda, beta) pair. Each feature group
     * claims a contiguous index range; {@code idx[0]} tracks the running offset, which
     * starts at 1.
     *
     * @return the active feature indices
     */
    protected IntArrayList getFeatureArray() {
        IntArrayList arr = new IntArrayList();
        int[] idx = {1};
        SRLFtrMap map = getFtrMap();

        addNgramFeatures(arr, idx, map);
        addBinaryFeatures(arr, idx);
        addDistanceFeature(arr, idx);
        addSetFeatures(arr, idx, map, 0, d_tree.getDeprelDepSet(i_beta));
        addStrFeatures(arr, idx, map, 1, getPredArg());

        return arr;
    }

    /**
     * Adds up to four binary structural features and reserves four slots. Slots 0-2 are
     * mutually exclusive:
     * <ol start="0">
     *   <li>lambda's head is beta;</li>
     *   <li>beta's head is lambda;</li>
     *   <li>{@code d_tree.isAncestor(beta, lambda)} holds;</li>
     *   <li>following beta up through heads while its dependency label matches
     *       {@link DepLib#M_VC}, some head reached has a {@link DepLib#DEPREL_SBJ}
     *       dependent.</li>
     * </ol>
     */
    protected void addBinaryFeatures(IntArrayList arr, int[] idx) {
        DepNode lambda = d_tree.get(i_lambda);
        DepNode beta = d_tree.get(i_beta);

        if (lambda.headId == i_beta) {
            arr.add(idx[0]);
        } else if (beta.headId == i_lambda) {
            arr.add(idx[0] + 1);
        } else if (d_tree.isAncestor(beta, lambda)) {
            arr.add(idx[0] + 2);
        }

        // for out-of-domain
        while (DepLib.M_VC.matcher(beta.deprel).matches()) {
            beta = d_tree.get(beta.headId);

            if (d_tree.getDeprelDepSet(beta.id).contains(DepLib.DEPREL_SBJ)) {
                arr.add(idx[0] + 3);
                break;
            }
        }

        idx[0] += 4;
    }

    /**
     * Adds a bucketed |beta - lambda| distance feature. The buckets are:
     * <ul>
     *   <li>0 for a distance of at most 5;</li>
     *   <li>1 for a distance of at most 10;</li>
     *   <li>2 otherwise.</li>
     * </ul>
     * Three slots are reserved.
     */
    protected void addDistanceFeature(IntArrayList arr, int[] idx) {
        int dist = Math.abs(i_beta - i_lambda);

        if (dist <= 5) {
            dist = 0;
        } else if (dist <= 10) {
            dist = 1;
        } else {
            dist = 2;
        }

        arr.add(idx[0] + dist);
        idx[0] += 3;
    }

    /**
     * Adds one feature for each string in {@code ftrs} that is known to extra lexicon
     * {@code ftrId}. The indices are added in ascending order, and the lexicon's full size
     * ({@code map.n_extra[ftrId]}) is reserved.
     */
    protected void addSetFeatures(IntArrayList arr, int[] idx, SRLFtrMap map, int ftrId,
            AbstractCollection<String> ftrs) {
        IntArrayList list = new IntArrayList();

        for (String ftr : ftrs) {
            int i = map.extraToIndex(ftrId, ftr);

            if (i >= 0) {
                list.add(idx[0] + i);
            }
        }

        int[] tmp = list.toArray();
        Arrays.sort(tmp);
        arr.add(tmp, 0, tmp.length);

        idx[0] += map.n_extra[ftrId];
    }

    /**
     * Adds the feature for {@code ftr}, if it is non-null and known to extra lexicon
     * {@code ftrId}. The lexicon's full size ({@code map.n_extra[ftrId]}) is reserved either
     * way.
     */
    protected void addStrFeatures(IntArrayList arr, int[] idx, SRLFtrMap map, int ftrId,
            String ftr) {
        if (ftr != null) {
            int index = map.extraToIndex(ftrId, ftr);

            if (index >= 0) {
                arr.add(idx[0] + index);
            }
        }

        idx[0] += map.n_extra[ftrId];
    }
}