/** * Copyright 2014, Emory University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.emory.clir.clearnlp.dependency; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Set; import java.util.StringJoiner; import java.util.function.Function; import com.carrotsearch.hppc.cursors.IntCursor; import edu.emory.clir.clearnlp.collection.set.IntHashSet; import edu.emory.clir.clearnlp.srl.SRLTree; import edu.emory.clir.clearnlp.util.StringUtils; import edu.emory.clir.clearnlp.util.arc.DEPArc; import edu.emory.clir.clearnlp.util.arc.SRLArc; import edu.emory.clir.clearnlp.util.constant.StringConst; /** * @since 1.0.0. * @author Jinho D. 
Choi ({@code jinho.choi@emory.edu}) */
public class DEPTree implements Iterable<DEPNode> {
    /** Backing array; slot 0 holds the artificial root, slots [1, n_size) hold the remaining nodes. */
    private DEPNode[] d_tree;
    /** Number of nodes currently stored in {@link #d_tree}, including the root. */
    private int n_size;

//  ====================================== Constructors ======================================

    /**
     * Creates a new tree that contains only the artificial root node.
     * @param size the expected number of nodes excluding the root; used as the initial capacity.
     */
    public DEPTree(int size) {
        init(size);
    }

    /**
     * Creates a tree from a list of items.
     * {@link DEPNode} items are added as they are; {@link String} items are wrapped in a new
     * node whose ID is the item's position in the list plus one.
     * Items of any other type are skipped.
     * @param list the items to add after the root.
     */
    public <T>DEPTree(List<T> list) {
        int size = list.size();
        init(size);

        for (int i = 0; i < size; i++) {
            T item = list.get(i);

            if (item instanceof DEPNode) {
                add((DEPNode)item);
            } else if (item instanceof String) {
                add(new DEPNode(i + 1, (String)item));
            }
        }
    }

    /**
     * Creates a copy of another tree: every non-root node is copied, then primary, secondary
     * and semantic heads are re-linked to the copied nodes (looked up by the ID of the old head).
     * Nodes without a primary head in the old tree are left without one in the copy.
     * @param oTree the tree to copy.
     */
    public DEPTree(DEPTree oTree) {
        int size = oTree.size();
        init(size - 1);

        // First pass: copy the nodes so that every head can be resolved in the second pass.
        for (int i = 1; i < size; i++) {
            DEPNode oNode = oTree.get(i);
            DEPNode nNode = new DEPNode(oNode);
            add(nNode);

            if (oNode.getSecondaryHeadArcList() != null) {
                nNode.initSecondaryHeads();
            }
            if (oNode.getSemanticHeadArcList() != null) {
                nNode.initSemanticHeads();
            }
        }

        // Second pass: re-create the arcs against the copied nodes.
        for (int i = 1; i < size; i++) {
            DEPNode oNode = oTree.get(i);
            DEPNode nNode = get(i);

            if (oNode.getSecondaryHeadArcList() != null) {
                for (DEPArc xHead : oNode.getSecondaryHeadArcList()) {
                    nNode.addSecondaryHead(new DEPArc(get(xHead.getNode().getID()), xHead.getLabel()));
                }
            }

            if (oNode.getSemanticHeadArcList() != null) {
                for (SRLArc sHead : oNode.getSemanticHeadArcList()) {
                    nNode.addSemanticHead(new SRLArc(get(sHead.getNode().getID()), sHead.getLabel(), sHead.getNumberedArgumentTag()));
                }
            }

            // A headless node is legal elsewhere in this class, so it must not crash the copy.
            DEPNode oHead = oNode.getHead();
            if (oHead != null) {
                nNode.setHead(get(oHead.getID()), oNode.getLabel());
            }
        }
    }

    /**
     * Allocates the backing array and adds the artificial root as the first node.
     * @param size the expected number of nodes excluding the root.
     */
    private void init(int size) {
        d_tree = new DEPNode[size + 1];
        n_size = 0;

        DEPNode root = new DEPNode();
        root.initRoot();
        add(root);
    }

//  ====================================== Tree operations ======================================

    /**
     * @param id the position of the node in this tree (0 is the root).
     * @return the node at the specific position if it exists; otherwise, {@code null}.
     */
    public DEPNode get(int id) {
        return (0 <= id && id < n_size) ? d_tree[id] : null;
    }

    /**
     * Appends the specific node to the end of this tree; the node's ID is not modified.
     * @param node the node to append.
     */
    public void add(DEPNode node) {
        increaseSize();
        d_tree[n_size++] = node;
    }

    /** Grows the backing array by 5 slots when it is full. */
    private void increaseSize() {
        if (n_size == d_tree.length) {
            DEPNode[] nTree = new DEPNode[n_size + 5];
            System.arraycopy(d_tree, 0, nTree, 0, n_size);
            d_tree = nTree;
        }
    }

    /** @return the number of nodes in this tree, including the root. */
    public int size() {
        return n_size;
    }

    /**
     * Removes the node at the specific position, detaches it from its head, and shifts all
     * following nodes down by one, renumbering their IDs.
     * @param id the position of the node to remove; the root (0) cannot be removed.
     * @throws IndexOutOfBoundsException if {@code id <= 0 || id >= size()}.
     */
    public void remove(int id) {
        if (id <= 0 || id >= n_size) {
            throw new IndexOutOfBoundsException("id: " + id + ", size: " + n_size);
        }

        d_tree[id].setHead(null, null);
        n_size--;

        for (int i = id; i < n_size; i++) {
            d_tree[i] = d_tree[i + 1];
            d_tree[i].setID(i);
        }

        // Drop the duplicate reference left in the vacated last slot.
        d_tree[n_size] = null;
    }

    /**
     * Inserts the specific node at the specific position, shifting the nodes at and after
     * that position up by one and renumbering their IDs.
     * @param id the position to insert at; must be in {@code [1, size()]}.
     * @param node the node to insert; its ID is set to {@code id}.
     * @throws IndexOutOfBoundsException if {@code id <= 0 || id > size()}.
     */
    public void insert(int id, DEPNode node) {
        if (id <= 0 || id > n_size) {
            throw new IndexOutOfBoundsException("id: " + id + ", size: " + n_size);
        }

        increaseSize();

        for (int i = n_size; i > id; i--) {
            d_tree[i] = d_tree[i - 1];
            d_tree[i].setID(i);
        }

        d_tree[id] = node;
        node.setID(id);
        n_size++;
    }

    /** Resets the IDs of all non-root nodes so that each ID equals the node's position. */
    public void resetNodeIDs() {
        resetNodeIDs(1);
    }

    /**
     * Resets the IDs of all nodes from the specific position onward to their positions.
     * @param beginID the first position to reset (inclusive).
     */
    private void resetNodeIDs(int beginID) {
        int size = size();

        for (int i = beginID; i < size; i++) {
            get(i).setID(i);
        }
    }

//  ====================================== Initialization ======================================

    /** Initializes the secondary heads of all non-root nodes in this tree. */
    public void initSecondaryHeads() {
        for (DEPNode node : this) {
            node.initSecondaryHeads();
        }
    }

    /** Initializes the semantic heads of all non-root nodes in this tree. */
    public void initSemanticHeads() {
        for (DEPNode node : this) {
            node.initSemanticHeads();
        }
    }

//  ====================================== Dependency ======================================

    /** @return the list of all nodes that are dependents of the artificial root. */
    public List<DEPNode> getRoots() {
        List<DEPNode> roots = new ArrayList<>();
        DEPNode root = get(DEPLib.ROOT_ID);

        for (DEPNode node : this) {
            if (node.isDependentOf(root)) {
                roots.add(node);
            }
        }

        return roots;
    }

    /** @return the first node that is a dependent of the artificial root if exists; otherwise, {@code null}. */
    public DEPNode getFirstRoot() {
        DEPNode root = get(DEPLib.ROOT_ID);

        for (DEPNode node : this) {
            if (node.isDependentOf(root)) {
                return node;
            }
        }

        return null;
    }

    /**
     * Compares the heads in this tree against gold-standard heads.
     * @param goldHeads the gold arcs indexed by node position (index 0 is ignored).
     * @param evalPunct if {@code false}, nodes whose simplified word-form consists only of punctuation are skipped.
     * @return {@code {total, las, uas}}: the number of evaluated nodes, the number of nodes with
     * the correct head and label, and the number of nodes with the correct head.
     */
    public int[] getScoreCounts(DEPArc[] goldHeads, boolean evalPunct) {
        int las = 0, uas = 0, total = 0, size = size();

        for (int i = 1; i < size; i++) {
            DEPNode node = get(i);

            if (!evalPunct && StringUtils.containsPunctuationOnly(node.getSimplifiedWordForm())) {
                continue;
            }

            DEPArc gold = goldHeads[i];
            DEPNode goldHead = gold.getNode();
            total++;

            // getHeads() emits arcs with a null node for headless nodes; count those as mismatches.
            if (goldHead != null && node.isDependentOf(get(goldHead.getID()))) {
                uas++;
                if (node.isLabel(gold.getLabel())) {
                    las++;
                }
            }
        }

        return new int[]{total, las, uas};
    }

    /**
     * Converts this tree into a projective tree by repeatedly re-attaching a non-projective
     * node to its grand-head; the new label is the old label, a direction marker
     * ({@code DEPLib.NPROJ_LEFT} or {@code DEPLib.NPROJ_RIGHT}), and the label of the old head.
     */
    public void projectivize() {
        IntHashSet ids = new IntHashSet();
        int size = size();

        for (int i = 1; i < size; i++) {
            ids.add(i);
        }

        DEPNode nonProj;

        while ((nonProj = getSmallestNonProjectiveArc(ids)) != null) {
            DEPNode head = nonProj.getHead();
            DEPNode gHead = head.getHead();
            String dir = (head.getID() < gHead.getID()) ? DEPLib.NPROJ_LEFT : DEPLib.NPROJ_RIGHT;
            nonProj.setHead(gHead, nonProj.getLabel() + dir + head.getLabel());
        }
    }

    /**
     * Called by {@link #projectivize()}.
     * Removes from {@code ids} every node whose arc is currently projective.
     * NOTE(review): despite the name, the node kept is the one with the LARGEST
     * non-projective distance ({@code np > max}) - confirm this is intended.
     * @param ids the positions of the candidate nodes; modified in place.
     * @return the selected non-projective node if exists; otherwise, {@code null}.
     */
    private DEPNode getSmallestNonProjectiveArc(IntHashSet ids) {
        IntHashSet remove = new IntHashSet();
        DEPNode nonProj = null;
        int max = 0;

        for (IntCursor cur : ids) {
            int id = cur.value;
            DEPNode wk = get(id);
            int np = getNonProjectiveDistance(wk);

            if (np == 0) {
                remove.add(id);
            } else if (np > max) {
                nonProj = wk;
                max = np;
            }
        }

        ids.removeAll(remove);
        return nonProj;
    }

    /**
     * @param node the node whose arc to its head is checked.
     * @return the absolute distance between the node and its head if some node strictly between
     * them is not a descendant of the head (i.e., the arc is non-projective); otherwise, 0.
     */
    private int getNonProjectiveDistance(DEPNode node) {
        DEPNode head = node.getHead();
        if (head == null) {
            return 0;
        }

        int bId = Math.min(node.getID(), head.getID());
        int eId = Math.max(node.getID(), head.getID());

        for (int j = bId + 1; j < eId; j++) {
            if (!get(j).isDescendantOf(head)) {
                return Math.abs(head.getID() - node.getID());
            }
        }

        return 0;
    }

    /**
     * @return {@code true} if, for some node with a head, a node strictly between the two has a
     * head outside of the span covered by them.
     */
    public boolean isNonProjective() {
        for (DEPNode node : this) {
            DEPNode head = node.getHead();
            if (head == null) {
                continue;
            }

            int bId = Math.min(node.getID(), head.getID());
            int eId = Math.max(node.getID(), head.getID());

            for (int j = bId + 1; j < eId; j++) {
                DEPNode wh = get(j).getHead();

                if (wh != null && (wh.getID() < bId || wh.getID() > eId)) {
                    return true;
                }
            }
        }

        return false;
    }

    /**
     * @return {@code true} if the head of some node is also a descendant of that node.
     * Nodes without a head are ignored.
     */
    public boolean containsCycle() {
        for (DEPNode node : this) {
            DEPNode head = node.getHead();

            if (head != null && head.isDescendantOf(node)) {
                return true;
            }
        }

        return false;
    }

//  --------------------------------- Semantics ---------------------------------

    /**
     * @param beginID the position to start searching after (exclusive).
     * @return the first semantic head after the specific position if exists; otherwise, {@code null}.
     */
    public DEPNode getNextSemanticHead(int beginID) {
        int size = size();

        for (int i = beginID + 1; i < size; i++) {
            DEPNode node = get(i);
            if (node.isSemanticHead()) {
                return node;
            }
        }

        return null;
    }

    /** @return {@code true} if any non-root node in this tree is a semantic head. */
    public boolean containsSemanticHead() {
        for (DEPNode node : this) {
            if (node.isSemanticHead()) {
                return true;
            }
        }

        return false;
    }

    /**
     * @return a list with one entry per node position (including the root), where the i'th
     * entry holds the semantic arguments of the i'th node; each argument arc points to the
     * argument node. Nodes whose semantic heads are not initialized contribute nothing.
     */
    public List<List<SRLArc>> getArgumentList() {
        int size = size();
        List<List<SRLArc>> list = new ArrayList<>(size);

        for (int i = 0; i < size; i++) {
            list.add(new ArrayList<SRLArc>());
        }

        for (DEPNode node : this) {
            List<SRLArc> arcs = node.getSemanticHeadArcList();
            if (arcs == null) {
                continue;
            }

            for (SRLArc arc : arcs) {
                list.get(arc.getNode().getID()).add(new SRLArc(node, arc.getLabel(), arc.getNumberedArgumentTag()));
            }
        }

        return list;
    }

    /**
     * @return A semantic tree representing a predicate-argument structure of the specific token if exists; otherwise, {@code null}.
     * @param predicateID the node ID of a predicate.
     */
    public SRLTree getSRLTree(int predicateID) {
        return getSRLTree(get(predicateID));
    }

    /**
     * @param predicate the predicate node.
     * @return the semantic tree holding all arguments of the specific predicate in this tree;
     * {@code null} if the predicate is {@code null} or is not a semantic head.
     */
    public SRLTree getSRLTree(DEPNode predicate) {
        if (predicate == null || !predicate.isSemanticHead()) {
            return null;
        }

        SRLTree tree = new SRLTree(predicate);

        for (DEPNode node : this) {
            SRLArc arc = node.getSemanticHeadArc(predicate);

            if (arc != null) {
                tree.addArgument(new SRLArc(node, arc.getLabel(), arc.getNumberedArgumentTag()));
            }
        }

        return tree;
    }

//  ====================================== Gold tags ======================================

    /** @return the POS tags of all nodes indexed by position; index 0 (the root) is {@code null}. */
    public String[] getPOSTags() {
        int size = size();
        String[] tags = new String[size];

        for (int i = 1; i < size; i++) {
            tags[i] = get(i).getPOSTag();
        }

        return tags;
    }

    /**
     * Assigns the specific POS tags to the non-root nodes.
     * @param tags the POS tags indexed by node position; index 0 is ignored.
     */
    public void setPOSTags(String[] tags) {
        int size = size();

        for (int i = 1; i < size; i++) {
            get(i).setPOSTag(tags[i]);
        }
    }

    /** @return the named entity tags of all nodes indexed by position; index 0 (the root) is {@code null}. */
    public String[] getNamedEntityTags() {
        int size = size();
        String[] tags = new String[size];

        for (int i = 1; i < size; i++) {
            tags[i] = get(i).getNamedEntityTag();
        }

        return tags;
    }

    /**
     * Assigns the specific named entity tags to the non-root nodes.
     * @param tags the named entity tags indexed by node position; index 0 is ignored.
     */
    public void setNamedEntityTags(String[] tags) {
        int size = size();

        for (int i = 1; i < size; i++) {
            get(i).setNamedEntityTag(tags[i]);
        }
    }

    /** @return the dependency arcs (head and label) of all nodes indexed by position. */
    public DEPArc[] getHeads() {
        return getHeads(size());
    }

    /**
     * @param endIndex the position to stop at (exclusive).
     * @return the dependency arcs (head and label) of the nodes in {@code [0, endIndex)}; the
     * root and any node without a head get an arc whose node and label are both {@code null}.
     */
    public DEPArc[] getHeads(int endIndex) {
        DEPArc[] heads = new DEPArc[endIndex];
        heads[0] = new DEPArc(null, null);

        for (int i = 1; i < endIndex; i++) {
            DEPNode node = get(i);
            DEPNode head = node.getHead();
            heads[i] = (head != null) ? new DEPArc(head, node.getLabel()) : new DEPArc(null, null);
        }

        return heads;
    }

    /**
     * Clears all dependencies in this tree, then assigns the specific heads to the nodes.
     * @param arcs the arcs indexed by node position; index 0 and arcs with a {@code null} node are ignored.
     */
    public void setHeads(DEPArc[] arcs) {
        int len = arcs.length;
        clearDependencies();

        for (int i = 1; i < len; i++) {
            DEPNode node = get(i);
            DEPArc arc = arcs[i];

            if (arc.getNode() != null) {
                node.setHead(arc.getNode(), arc.getLabel());
            }
        }
    }

    /** Clears the dependencies of all nodes in this tree, including the root. */
    public void clearDependencies() {
        int size = size();

        for (int i = 0; i < size; i++) {
            get(i).clearDependencies();
        }
    }

    /**
     * @param key the key of the feature.
     * @return the values of the specific feature for all nodes indexed by position; index 0 (the root) is {@code null}.
     */
    public String[] getFeatureTags(String key) {
        int size = size();
        String[] tags = new String[size];

        for (int i = 1; i < size; i++) {
            tags[i] = get(i).getFeat(key);
        }

        return tags;
    }

    /**
     * Puts the specific feature values to the non-root nodes.
     * @param key the key of the feature.
     * @param tags the feature values indexed by node position; index 0 is ignored.
     */
    public void setFeatureTags(String key, String[] tags) {
        int size = size();

        for (int i = 1; i < size; i++) {
            get(i).putFeat(key, tags[i]);
        }
    }

    /**
     * Removes the specific feature from all non-root nodes.
     * @param key the key of the feature.
     */
    public void clearFeatureTags(String key) {
        int size = size();

        for (int i = 1; i < size; i++) {
            get(i).removeFeat(key);
        }
    }

    /** @return the roleset IDs of all nodes indexed by position; index 0 (the root) is {@code null}. */
    public String[] getRolesetIDs() {
        int size = size();
        String[] rolesets = new String[size];

        for (int i = 1; i < size; i++) {
            rolesets[i] = get(i).getRolesetID();
        }

        return rolesets;
    }

    /** Clears the roleset IDs of all non-root nodes. */
    public void clearRolesetIDs() {
        int size = size();

        for (int i = 1; i < size; i++) {
            get(i).clearRolesetID();
        }
    }

    /**
     * @return a 2-dimensional array whose i'th row holds copies of the semantic head arcs of
     * the i'th node; row 0 (the root) and rows of nodes whose semantic heads are not
     * initialized are empty.
     */
    public SRLArc[][] getSemanticHeads() {
        int size = size();
        SRLArc[][] sHeads = new SRLArc[size][];
        sHeads[0] = new SRLArc[0];

        for (int i = 1; i < size; i++) {
            List<SRLArc> arcs = get(i).getSemanticHeadArcList();
            int len = (arcs != null) ? arcs.size() : 0;
            SRLArc[] heads = new SRLArc[len];

            for (int j = 0; j < len; j++) {
                heads[j] = new SRLArc(arcs.get(j));
            }

            sHeads[i] = heads;
        }

        return sHeads;
    }

    /**
     * Clears all semantic heads in this tree, then adds the specific semantic heads to the nodes.
     * @param semanticArcs the arcs whose i'th row holds the semantic heads of the i'th node; row 0 is ignored.
     */
    public void setSemanticHeads(SRLArc[][] semanticArcs) {
        int len = semanticArcs.length;
        clearSemanticHeads();

        for (int i = 1; i < len; i++) {
            DEPNode node = get(i);

            for (SRLArc arc : semanticArcs[i]) {
                node.addSemanticHead(arc);
            }
        }
    }

    /** Clears the semantic heads of all non-root nodes. */
    public void clearSemanticHeads() {
        int size = size();

        for (int i = 1; i < size; i++) {
            get(i).clearSemanticHeads();
        }
    }

    /**
     * @return the nodes reachable from the root, in depth-first post-order (every node
     * comes after all of its dependents; the root is last).
     */
    public List<DEPNode> getDepthFirstNodeList() {
        List<DEPNode> list = new ArrayList<>(size());
        traverseDepthFirst(list, get(DEPLib.ROOT_ID));
        return list;
    }

    /**
     * Recursively adds the dependents of the specific node, then the node itself.
     * @param list the list to add the visited nodes to.
     * @param node the node to traverse from.
     */
    private void traverseDepthFirst(List<DEPNode> list, DEPNode node) {
        for (DEPNode child : node.getDependentList()) {
            traverseDepthFirst(list, child);
        }

        list.add(node);
    }

    /** @return the number of non-root nodes that have a head. */
    public int countHeaded() {
        int count = 0;

        for (DEPNode node : this) {
            if (node.hasHead()) {
                count++;
            }
        }

        return count;
    }

//  ====================================== String ======================================

    @Override
    public String toString() {
        return toString(DEPNode::toString);
    }

    /**
     * @param f the function converting each node to a string.
     * @return the string representations of all non-root nodes, one per line.
     */
    public String toString(Function<DEPNode,String> f) {
        StringJoiner build = new StringJoiner(StringConst.NEW_LINE);

        for (DEPNode node : this) {
            build.add(f.apply(node));
        }

        return build.toString();
    }

    /**
     * @return an iterator over the non-root nodes in this tree, in position order.
     * The iterator does not support {@link Iterator#remove()}.
     */
    @Override
    public Iterator<DEPNode> iterator() {
        return new Iterator<DEPNode>() {
            private int current_index = 1;

            @Override
            public boolean hasNext() {
                return current_index < size();
            }

            @Override
            public DEPNode next() {
                if (!hasNext()) {
                    throw new java.util.NoSuchElementException();
                }

                return d_tree[current_index++];
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        };
    }

    /**
     * @return the internal backing array itself (not a copy). Its length can exceed
     * {@link #size()}; slots at and after {@code size()} must not be relied upon.
     */
    public DEPNode[] toNodeArray() {
        return d_tree;
    }

    /**
     * Joins the string representations of the nodes in the specific range.
     * @param f the function converting each node to a string.
     * @param delim the delimiter put between the strings.
     * @param beginIndex inclusive.
     * @param endIndex exclusive.
     */
    public String join(Function<DEPNode,String> f, String delim, int beginIndex, int endIndex) {
        StringJoiner joiner = new StringJoiner(delim);

        for (int i = beginIndex; i < endIndex; i++) {
            joiner.add(f.apply(get(i)));
        }

        return joiner.toString();
    }

//  ====================================== N-grams ======================================

    /**
     * @param f the function converting each node to a string.
     * @param delim the delimiter put between the strings of an n-gram.
     * @param n the maximum length of the n-grams.
     * @return the set of all 1- to n-grams over the non-root nodes.
     */
    public Set<String> getNgrams(Function<DEPNode,String> f, String delim, int n) {
        Set<String> ngrams = new HashSet<>();
        int size = size();

        for (int i = 1; i < size; i++) {
            for (int j = 0; j < n; j++) {
                // i-j is the first position of the n-gram; it must not reach the root.
                if (i - j > 0) {
                    ngrams.add(join(f, delim, i - j, i + 1));
                }
            }
        }

        return ngrams;
    }

    /**
     * Generates mixed n-grams: for every 2- to n-gram over the non-root nodes and every
     * position within it, one string is generated where the node at that position is
     * converted by {@code f1} and all the other nodes are converted by {@code f2}.
     * @param f1 the function applied to the one selected node of each n-gram.
     * @param f2 the function applied to all the other nodes of each n-gram.
     * @param delim the delimiter put between the strings of an n-gram.
     * @param n the maximum length of the n-grams.
     * @return the set of all generated strings.
     */
    public Set<String> getNgrams(Function<DEPNode,String> f1, Function<DEPNode,String> f2, String delim, int n) {
        Set<String> ngrams = new HashSet<>();
        int size = size();

        for (int i = 1; i < size; i++) {
            for (int j = 1; j < n; j++) {
                if (i - j <= 0) {
                    continue;
                }

                // l is the offset (within the n-gram) of the node converted by f1.
                for (int l = 0; l <= j; l++) {
                    StringJoiner joiner = new StringJoiner(delim);

                    for (int k = i - j, c = 0; k <= i; k++, c++) {
                        joiner.add((l == c) ? f1.apply(get(k)) : f2.apply(get(k)));
                    }

                    ngrams.add(joiner.toString());
                }
            }
        }

        return ngrams;
    }
}