package edu.berkeley.nlp.PCFGLA;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import edu.berkeley.nlp.ling.CollinsHeadFinder;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.util.Filter;
/**
* Class which contains code for annotating and binarizing trees for the parser's use, and debinarizing and unannotating them for scoring.
*/
public class TreeAnnotations implements java.io.Serializable {
private static final long serialVersionUID = 1L;
static CollinsHeadFinder headFinder = new CollinsHeadFinder();
/**
 * Annotates and binarizes a parse tree using the default settings for the two
 * optional flags of the seven-argument overload: unary parents are not marked
 * ({@code annotateUnaryParents == false}) and grammar nonterminals are marked
 * ({@code markGrammarSymbols == true}).
 * <p>
 * Sketch of the transformation for two levels of ancestor annotation:
 * <pre>
 *   Tag becomes Tag^Parent^Grandparent
 *   Tag^Parent^Grandparent produces A^Tag^Parent B... C...
 * is binarized into
 *   Tag^Parent^Grandparent produces A^Tag^Parent @Tag^Parent^Grandparent->_A^Tag^Parent
 *   @Tag^Parent^Grandparent->_A^Tag^Parent produces B^Tag^Parent @Tag^Parent^Grandparent->_A^Tag^Parent_B^Tag^Parent
 * </pre>
 * and finally the excess {@code _*} history is trimmed to control the amount
 * of horizontal context that is remembered.
 *
 * @param unAnnotatedTree the tree to process
 * @param nVerticalAnnotations vertical annotation level (1, 2 or 3)
 * @param nHorizontalAnnotations horizontal history to keep (-1 keeps everything)
 * @param binarization binarization scheme
 * @param manualAnnotation selects the manual variant when nVerticalAnnotations is 2
 * @return the annotated, binarized and trimmed tree
 */
public static Tree<String> processTree(Tree<String> unAnnotatedTree,
    int nVerticalAnnotations, int nHorizontalAnnotations,
    Binarization binarization, boolean manualAnnotation) {
  final boolean annotateUnaryParents = false;
  final boolean markGrammarSymbols = true;
  return processTree(unAnnotatedTree, nVerticalAnnotations, nHorizontalAnnotations,
      binarization, manualAnnotation, annotateUnaryParents, markGrammarSymbols);
}
/**
 * Runs the full pipeline: vertical annotation, binarization, and trimming of
 * the horizontal history recorded in intermediate labels.
 *
 * @param unAnnotatedTree the tree to process
 * @param nVerticalAnnotations 3 = parent and grandparent annotation, 2 = parent
 *        annotation (manual or plain), 1 = optional grammar/unary markers only
 * @param nHorizontalAnnotations passed on to the label-forgetting step
 * @param binarization binarization scheme
 * @param manualAnnotation only consulted when nVerticalAnnotations is 2
 * @param annotateUnaryParents only consulted when nVerticalAnnotations is 1
 * @param markGrammarSymbols only consulted when nVerticalAnnotations is 1
 * @return the processed tree
 * @throws Error if nVerticalAnnotations is not 1, 2 or 3
 */
public static Tree<String> processTree(Tree<String> unAnnotatedTree,
    int nVerticalAnnotations, int nHorizontalAnnotations,
    Binarization binarization, boolean manualAnnotation,
    boolean annotateUnaryParents, boolean markGrammarSymbols) {
  Tree<String> annotated = unAnnotatedTree;
  switch (nVerticalAnnotations) {
    case 3:
      annotated = annotateVerticallyTwice(unAnnotatedTree, "", "");
      break;
    case 2:
      annotated = manualAnnotation
          ? annotateManuallyVertically(unAnnotatedTree, "")
          : annotateVertically(unAnnotatedTree, "");
      break;
    case 1:
      // Both markers are optional; the unary marker is applied on top of
      // whatever the grammar marker produced (or on the raw tree).
      if (markGrammarSymbols) {
        annotated = markGrammarNonterminals(unAnnotatedTree, "");
      }
      if (annotateUnaryParents) {
        annotated = markUnaryParents(annotated);
      }
      break;
    default:
      throw new Error("the code does not exist to annotate vertically "+nVerticalAnnotations+" times");
  }
  Tree<String> binarized = binarizeTree(annotated, binarization);
  return forgetLabels(binarized, nHorizontalAnnotations);
}
/**
 * Binarizes a tree with the requested scheme (left, right, parent or head
 * binarization) by dispatching to the matching helper.
 *
 * @param tree the tree to binarize
 * @param binarization the type of binarization used
 * @return the binarized tree, or {@code null} when the scheme is not one of
 *         LEFT, RIGHT, PARENT or HEAD
 */
public static Tree<String> binarizeTree(Tree<String> tree, Binarization binarization) {
  Tree<String> binarized = null;
  switch (binarization) {
    case LEFT:
      binarized = leftBinarizeTree(tree);
      break;
    case RIGHT:
      binarized = rightBinarizeTree(tree);
      break;
    case PARENT:
      binarized = parentBinarizeTree(tree);
      break;
    case HEAD:
      binarized = headBinarizeTree(tree);
      break;
  }
  return binarized;
}
/**
 * Appends parent and grandparent markers to every non-leaf label.
 *
 * @param tree subtree to annotate
 * @param parentLabel1 marker for the parent ("^" + parent label, or "" at the top)
 * @param parentLabel2 marker for the grandparent, same format
 * @return the leaf itself for leaves, otherwise a new annotated subtree
 */
private static Tree<String> annotateVerticallyTwice(Tree<String> tree, String parentLabel1, String parentLabel2) {
  if (tree.isLeaf()) {
    return tree;
  }
  final String ownLabel = tree.getLabel();
  final String markerForChildren = "^" + ownLabel;
  List<Tree<String>> annotatedKids = new ArrayList<Tree<String>>();
  for (Tree<String> kid : tree.getChildren()) {
    // this node becomes the child's parent; this node's parent becomes its grandparent
    annotatedKids.add(annotateVerticallyTwice(kid, markerForChildren, parentLabel1));
  }
  return new Tree<String>(ownLabel + parentLabel1 + parentLabel2, annotatedKids);
}
/**
 * Appends a parent marker to every non-leaf label.
 *
 * @param tree subtree to annotate
 * @param parentLabel marker to append to this node ("^" + parent label, or "" at the top)
 * @return the leaf itself for leaves, otherwise a new annotated subtree
 */
private static Tree<String> annotateVertically(Tree<String> tree, String parentLabel) {
  if (tree.isLeaf()) {
    return tree;
  }
  final String ownLabel = tree.getLabel();
  final String markerForChildren = "^" + ownLabel;
  List<Tree<String>> annotatedKids = new ArrayList<Tree<String>>();
  for (Tree<String> kid : tree.getChildren()) {
    annotatedKids.add(annotateVertically(kid, markerForChildren));
  }
  return new Tree<String>(ownLabel + parentLabel, annotatedKids);
}
/**
 * Appends the given suffix to this node and "^g" to every node below it,
 * stopping at preterminals, which are returned unchanged.
 *
 * @param tree subtree to mark
 * @param parentLabel suffix appended to this node's label (callers in this
 *        file pass "" for the top node)
 * @return the preterminal itself, or a new marked subtree
 */
private static Tree<String> markGrammarNonterminals(Tree<String> tree, String parentLabel) {
  if (tree.isPreTerminal()) {
    return tree;
  }
  List<Tree<String>> markedKids = new ArrayList<Tree<String>>();
  for (Tree<String> kid : tree.getChildren()) {
    markedKids.add(markGrammarNonterminals(kid, "^g"));
  }
  String markedLabel = tree.getLabel() + parentLabel;
  return new Tree<String>(markedLabel, markedKids);
}
/**
 * Appends "^u" to the label of every node above the preterminal level that
 * has exactly one child, except nodes labelled "ROOT". Preterminals are
 * returned unchanged.
 *
 * @param tree subtree to mark
 * @return the preterminal itself, or a new marked subtree
 */
private static Tree<String> markUnaryParents(Tree<String> tree) {
  if (tree.isPreTerminal()) {
    return tree;
  }
  List<Tree<String>> markedKids = new ArrayList<Tree<String>>();
  for (Tree<String> kid : tree.getChildren()) {
    markedKids.add(markUnaryParents(kid));
  }
  String suffix = "";
  boolean isRootNode = tree.getLabel().equals("ROOT");
  if (!isRootNode && markedKids.size() == 1) {
    suffix = "^u";
  }
  return new Tree<String>(tree.getLabel() + suffix, markedKids);
}
/**
 * Parent-annotates all nodes above the preterminal level, and additionally
 * those preterminals whose label contains one of DT, RB, IN, AUX, CC or %.
 * Other preterminals are returned unchanged.
 *
 * @param tree subtree to annotate
 * @param parentLabel marker to append ("^" + parent label, or "" at the top)
 * @return the annotated subtree
 */
private static Tree<String> annotateManuallyVertically(Tree<String> tree, String parentLabel) {
  if (!tree.isPreTerminal()) {
    final String markerForChildren = "^" + tree.getLabel();
    List<Tree<String>> annotatedKids = new ArrayList<Tree<String>>();
    for (Tree<String> kid : tree.getChildren()) {
      annotatedKids.add(annotateManuallyVertically(kid, markerForChildren));
    }
    return new Tree<String>(tree.getLabel() + parentLabel, annotatedKids);
  }
  // Only some of the POS tags are split by their parent.
  final String[] splitTags = { "DT", "RB", "IN", "AUX", "CC", "%" };
  final String posLabel = tree.getLabel();
  for (int i = 0; i < splitTags.length; i++) {
    if (posLabel.contains(splitTags[i])) {
      return new Tree<String>(posLabel + parentLabel, tree.getChildren());
    }
  }
  return tree;
}
/**
 * Replaces node labels with a coarse alphabet: the root keeps its label,
 * preterminals become "Z", nodes whose label starts with '@' become "@X",
 * and every other node becomes "X". Children of preterminals are kept as is.
 *
 * @param tree subtree to relabel
 * @param isRoot true only for the top-level call; the root label is preserved
 * @return a new tree with the coarse labels
 */
private static Tree<String> deleteLabels(Tree<String> tree, boolean isRoot) {
  String label = tree.getLabel();
  String newLabel;
  if (isRoot) {
    newLabel = label;
  } else if (tree.isPreTerminal()) {
    return new Tree<String>("Z", tree.getChildren());
  } else if (label.startsWith("@")) {
    // startsWith instead of charAt(0): an empty label must not throw
    // StringIndexOutOfBoundsException, it is simply treated as a plain node.
    newLabel = "@X";
  } else {
    newLabel = "X";
  }
  List<Tree<String>> transformedChildren = new ArrayList<Tree<String>>();
  for (Tree<String> child : tree.getChildren()) {
    transformedChildren.add(deleteLabels(child, false));
  }
  return new Tree<String>(newLabel, transformedChildren);
}
/**
 * Replaces phrasal category labels with a coarse alphabet but keeps POS tags:
 * the root keeps its label, preterminal subtrees are returned unchanged,
 * nodes whose label starts with '@' become "@X", and every other node
 * becomes "X".
 *
 * @param tree subtree to relabel
 * @param isRoot true only for the top-level call; the root label is preserved
 * @return the relabelled tree (preterminal subtrees are shared, not copied)
 */
private static Tree<String> deletePC(Tree<String> tree, boolean isRoot) {
  String label = tree.getLabel();
  String newLabel;
  if (isRoot) {
    newLabel = label;
  } else if (tree.isPreTerminal()) {
    return tree;
  } else if (label.startsWith("@")) {
    // startsWith instead of charAt(0): an empty label must not throw
    // StringIndexOutOfBoundsException, it is simply treated as a plain node.
    newLabel = "@X";
  } else {
    newLabel = "X";
  }
  List<Tree<String>> transformedChildren = new ArrayList<Tree<String>>();
  for (Tree<String> child : tree.getChildren()) {
    transformedChildren.add(deletePC(child, false));
  }
  return new Tree<String>(newLabel, transformedChildren);
}
/**
 * Trims the '_'-separated history recorded in node labels so that only a
 * limited number of trailing entries remain.
 * <ul>
 * <li>-1: the tree is returned untouched.</li>
 * <li>0: everything from the character before the first '&gt;' onwards is
 *     dropped (only when that position is greater than 0).</li>
 * <li>1 or 2: for labels containing at least two '_' (with the second-to-last
 *     one not at index 0), the text before the first '_' is kept and followed
 *     by the last one or two '_'-entries respectively.</li>
 * </ul>
 * Leaves are copied with their label unchanged.
 *
 * @param tree subtree to process
 * @param nHorizontalAnnotation amount of history to keep
 * @return a new tree with trimmed labels (or the input itself for -1)
 * @throws Error if a label needs trimming and the level is not 0, 1 or 2
 */
private static Tree<String> forgetLabels(Tree<String> tree, int nHorizontalAnnotation) {
  if (nHorizontalAnnotation == -1) {
    return tree;
  }
  final String fullLabel = tree.getLabel();
  if (tree.isLeaf()) {
    return new Tree<String>(fullLabel);
  }

  String trimmedLabel = fullLabel;
  if (nHorizontalAnnotation == 0) {
    // cut just before the "->" arrow, i.e. one character ahead of the '>'
    int arrowStart = fullLabel.indexOf('>') - 1;
    if (arrowStart > 0) {
      trimmedLabel = fullLabel.substring(0, arrowStart);
    }
  } else {
    int firstUnderscore = fullLabel.indexOf('_');
    int lastUnderscore = fullLabel.lastIndexOf('_');
    // a negative start index makes lastIndexOf return -1, which covers
    // labels with zero or one underscore
    int penultimateUnderscore = fullLabel.lastIndexOf('_', lastUnderscore - 1);
    if (penultimateUnderscore > 0) {
      switch (nHorizontalAnnotation) {
        case 2:
          trimmedLabel = fullLabel.substring(0, firstUnderscore)
              + fullLabel.substring(penultimateUnderscore);
          break;
        case 1:
          trimmedLabel = fullLabel.substring(0, firstUnderscore)
              + fullLabel.substring(lastUnderscore);
          break;
        default:
          throw new Error("code does not exist to horizontally annotate at level "+ nHorizontalAnnotation);
      }
    }
  }

  List<Tree<String>> trimmedKids = new ArrayList<Tree<String>>();
  for (Tree<String> kid : tree.getChildren()) {
    trimmedKids.add(forgetLabels(kid, nHorizontalAnnotation));
  }
  return new Tree<String>(trimmedLabel, trimmedKids);
}
/**
 * Binarizes a tree from the left: the first child is kept directly and the
 * remaining children are pushed into a chain of intermediate nodes labelled
 * "@" + label + "->" followed by the labels of the children already generated.
 * Leaves and unary nodes are copied (unary children recursively).
 *
 * @param tree the tree to binarize
 * @return a new binarized tree
 */
static Tree<String> leftBinarizeTree(Tree<String> tree) {
  final String nodeLabel = tree.getLabel();
  final List<Tree<String>> kids = tree.getChildren();
  if (tree.isLeaf()) {
    return new Tree<String>(nodeLabel);
  }
  if (kids.size() == 1) {
    Tree<String> onlyKid = leftBinarizeTree(kids.get(0));
    return new Tree<String>(nodeLabel, Collections.singletonList(onlyKid));
  }
  // Two or more children: build the chain, then hoist its top-level
  // children under the original label.
  Tree<String> chain = leftBinarizeTreeHelper(tree, 0, "@" + nodeLabel + "->");
  return new Tree<String>(nodeLabel, chain.getChildren());
}
/**
 * Builds one intermediate node of the left binarization: its left child is
 * the binarized child at index numChildrenGenerated, its right child is
 * either the binarized last child or the next intermediate node.
 *
 * @param tree the n-ary node being decomposed
 * @param numChildrenGenerated index of the child emitted at this level
 * @param intermediateLabel label of the node created at this level
 * @return the intermediate node
 */
private static Tree<String> leftBinarizeTreeHelper(Tree<String> tree, int numChildrenGenerated, String intermediateLabel) {
  final List<Tree<String>> siblings = tree.getChildren();
  final Tree<String> current = siblings.get(numChildrenGenerated);
  final int secondToLast = siblings.size() - 2;

  List<Tree<String>> pair = new ArrayList<Tree<String>>(2);
  pair.add(leftBinarizeTree(current));
  if (numChildrenGenerated == secondToLast) {
    // only one sibling remains: attach it directly
    pair.add(leftBinarizeTree(siblings.get(numChildrenGenerated + 1)));
  } else if (numChildrenGenerated < secondToLast) {
    // more remain: extend the history with the child just emitted
    String nextLabel = intermediateLabel + "_" + current.getLabel();
    pair.add(leftBinarizeTreeHelper(tree, numChildrenGenerated + 1, nextLabel));
  }
  return new Tree<String>(intermediateLabel, pair);
}
/**
 * Binarizes a tree from the right: the last child is kept directly and the
 * preceding children are pushed into a chain of intermediate nodes labelled
 * "@" + label + "->" followed by the labels of the children already generated.
 * Leaves and unary nodes are copied (unary children recursively).
 *
 * @param tree the tree to binarize
 * @return a new binarized tree
 */
static Tree<String> rightBinarizeTree(Tree<String> tree) {
  final String nodeLabel = tree.getLabel();
  final List<Tree<String>> kids = tree.getChildren();
  if (tree.isLeaf()) {
    return new Tree<String>(nodeLabel);
  }
  if (kids.size() == 1) {
    Tree<String> onlyKid = rightBinarizeTree(kids.get(0));
    return new Tree<String>(nodeLabel, Collections.singletonList(onlyKid));
  }
  // Two or more children: build the chain starting from the last child,
  // then hoist its top-level children under the original label.
  int lastIndex = kids.size() - 1;
  Tree<String> chain = rightBinarizeTreeHelper(tree, lastIndex, "@" + nodeLabel + "->");
  return new Tree<String>(nodeLabel, chain.getChildren());
}
/**
 * Builds one intermediate node of the right binarization: its right child is
 * the binarized child at index numChildrenLeft, its left child is either the
 * binarized first child or the next intermediate node.
 *
 * @param tree the n-ary node being decomposed
 * @param numChildrenLeft index of the child emitted at this level (also the
 *        number of children still to its left)
 * @param intermediateLabel label of the node created at this level
 * @return the intermediate node
 */
private static Tree<String> rightBinarizeTreeHelper(Tree<String> tree,
    int numChildrenLeft, String intermediateLabel) {
  final List<Tree<String>> siblings = tree.getChildren();
  final Tree<String> current = siblings.get(numChildrenLeft);

  List<Tree<String>> pair = new ArrayList<Tree<String>>(2);
  if (numChildrenLeft == 1) {
    // only one sibling remains on the left: attach it directly
    pair.add(rightBinarizeTree(siblings.get(numChildrenLeft - 1)));
  } else if (numChildrenLeft > 1) {
    // more remain: extend the history with the child emitted at this level
    String nextLabel = intermediateLabel + "_" + current.getLabel();
    pair.add(rightBinarizeTreeHelper(tree, numChildrenLeft - 1, nextLabel));
  }
  pair.add(rightBinarizeTree(current));
  return new Tree<String>(intermediateLabel, pair);
}
/**
 * Binarize a tree around the head symbol. That is, when there is an n-ary
 * rule, with n > 2, we split it into a series of binary rules whose labels
 * start with {@code @JJ-R} (if JJ is the head of the rule). The suffix
 * (-R or -L) is used to indicate whether we're producing to the right or to
 * the left of the head symbol. Thus, the head symbol is always the deepest
 * symbol on the tree we've created.
 * <p>
 * Delegates to {@code headParentBinarizeTree} with {@code Binarization.HEAD}.
 *
 * @param tree the tree to binarize
 * @return the head-binarized tree
 */
static Tree<String> headBinarizeTree(Tree<String> tree) {
return headParentBinarizeTree(Binarization.HEAD,tree);
}
/**
 * Binarize a tree around the head symbol, but with symbol names derived from
 * the parent symbol, rather than the head (as in {@link #headBinarizeTree}).
 * <p>
 * Delegates to {@code headParentBinarizeTree} with {@code Binarization.PARENT}.
 *
 * @param tree the tree to binarize
 * @return the parent-binarized tree
 */
static Tree<String> parentBinarizeTree(Tree<String> tree) {
return headParentBinarizeTree(Binarization.PARENT,tree);
}
/**
 * Binarizes a tree around its head, with the names of the new symbols derived
 * from either the parent or the head (as determined by binarization).
 * <p>
 * Nodes without children are returned as is; nodes with one or two children
 * are copied with recursively binarized children. Only nodes with more than
 * two children need real work, which is done by
 * {@code headParentBinarizeTreeHelper}; its result becomes the single child
 * of the returned node.
 *
 * @param binarization determines whether the newly created symbols are based
 *        on the head symbol (HEAD) or on the parent symbol (PARENT)
 * @param tree the tree to binarize
 * @return the binarized tree
 */
private static Tree<String> headParentBinarizeTree(Binarization binarization, Tree<String> tree) {
  final List<Tree<String>> children = tree.getChildren();
  final int arity = children.size();
  if (arity == 0) {
    return tree;
  }

  final List<Tree<String>> kids;
  if (arity <= 2) {
    // already unary or binary: just recurse into the children
    kids = new ArrayList<Tree<String>>(arity);
    for (Tree<String> child : children) {
      kids.add(headParentBinarizeTree(binarization, child));
    }
  } else {
    Tree<String> head = headFinder.determineHead(tree);
    kids = new ArrayList<Tree<String>>(1);
    kids.add(headParentBinarizeTreeHelper(binarization, tree, 0, arity - 1, head, false, ""));
  }
  return new Tree<String>(tree.getLabel(), kids);
}
/**
 * Builds one intermediate node of a head/parent binarization. Children to the
 * left of the head are peeled off from the left; once the head is the
 * leftmost remaining child, the rest are peeled off from the right, so the
 * head ends up deepest in the generated chain. Intermediate labels have the
 * form "@" + symbol + ("-L" or "-R") + "->" + history.
 *
 * @param binarization the type of new symbols to generate: HEAD uses the head
 *        child's label, PARENT uses the label of tree
 * @param tree the n-ary node being decomposed
 * @param leftChild the index of the leftmost child remaining to be binarized
 * @param rightChild the index of the rightmost child remaining to be binarized
 * @param head the head child of tree (compared by identity against the children)
 * @param right whether the head has already been reached as the leftmost
 *        remaining child
 * @param productionHistory the "_"-prefixed labels of the children generated
 *        so far
 * @return the intermediate node for this level
 * @throws Error if head is null
 */
static Tree<String> headParentBinarizeTreeHelper(Binarization binarization,
    Tree<String> tree, int leftChild, int rightChild, Tree<String> head,
    boolean right, String productionHistory) {
  if (head == null) {
    throw new Error("head is null");
  }
  final List<Tree<String>> children = tree.getChildren();

  // Have we finally come to the head? (identity comparison on purpose)
  final boolean pastHead = right || children.get(leftChild) == head;

  final String symbol = (binarization == Binarization.HEAD)
      ? head.getLabel()
      : (binarization == Binarization.PARENT ? tree.getLabel() : null);
  final String direction = pastHead ? "-R" : "-L";
  final String parentLabel = "@" + symbol + direction + "->" + productionHistory;

  // A single remaining child only needs a unary node.
  if (leftChild == rightChild) {
    ArrayList<Tree<String>> single = new ArrayList<Tree<String>>(1);
    single.add(headParentBinarizeTree(binarization, children.get(leftChild)));
    return new Tree<String>(parentLabel, single);
  }

  ArrayList<Tree<String>> pair = new ArrayList<Tree<String>>(2);
  if (pastHead) {
    // to the right of the head: peel off the rightmost child
    Tree<String> peeled = children.get(rightChild);
    String history = productionHistory + "_" + peeled.getLabel();
    pair.add(headParentBinarizeTreeHelper(binarization, tree, leftChild, rightChild - 1,
        head, pastHead, history));
    pair.add(headParentBinarizeTree(binarization, peeled));
  } else {
    // to the left of the head: peel off the leftmost child
    Tree<String> peeled = children.get(leftChild);
    pair.add(headParentBinarizeTree(binarization, peeled));
    String history = productionHistory + "_" + peeled.getLabel();
    pair.add(headParentBinarizeTreeHelper(binarization, tree, leftChild + 1, rightChild,
        head, pastHead, history));
  }
  return new Tree<String>(parentLabel, pair);
}
/**
 * Variant of {@code unAnnotateTree} that treats labels beginning with "Y"
 * (rather than "@") as intermediate nodes: those nodes are spliced out with
 * {@code Trees.spliceNodes} and the result is passed through
 * {@code Trees.FunctionNodeStripper}.
 *
 * @param annotatedTree the annotated tree
 * @return the tree with "Y" nodes spliced out and labels stripped
 */
public static Tree<String> unAnnotateTreeSpecial(Tree<String> annotatedTree) {
  Filter<String> isIntermediate = new Filter<String>() {
    public boolean accept(String s) {
      return s.startsWith("Y");
    }
  };
  Tree<String> spliced = Trees.spliceNodes(annotatedTree, isIntermediate);
  Trees.FunctionNodeStripper stripper = new Trees.FunctionNodeStripper();
  return stripper.transformTree(spliced);
}
/**
 * Undoes binarization and annotation. Intermediate nodes, i.e. nodes whose
 * label begins with "@" but is not exactly "@", are spliced out with
 * {@code Trees.spliceNodes}; the result is then passed through
 * {@code Trees.FunctionNodeStripper}, which (per the original comment here)
 * removes all material following the base symbol, cutting at the leftmost
 * -, ^ or : character.
 * Example: a node labelled @NP->DT_JJ is spliced out, and a node labelled
 * NP^S is reduced to NP.
 *
 * @param annotatedTree the annotated, binarized tree
 * @return the debinarized tree with stripped labels
 */
public static Tree<String> unAnnotateTree(Tree<String> annotatedTree) {
  Filter<String> isIntermediate = new Filter<String>() {
    public boolean accept(String s) {
      // a lone "@" is kept as a regular label
      return s.length() > 1 && s.charAt(0) == '@';
    }
  };
  Tree<String> spliced = Trees.spliceNodes(annotatedTree, isIntermediate);
  Trees.FunctionNodeStripper stripper = new Trees.FunctionNodeStripper();
  return stripper.transformTree(spliced);
}
/**
 * Small manual test: reads a hard-coded sample tree, then for every
 * binarization type prints the binarized tree and the result of
 * un-annotating it again. An Error thrown for a binarization type is
 * reported as "binarization not implemented".
 *
 * @param args ignored
 */
public static void main(String args[]) {
  final String sample = "((S (NP (DT the) (JJ quick) (JJ (AA (BB (CC brown)))) (NN fox)) (VP (VBD jumped) (PP (IN over) (NP (DT the) (JJ lazy) (NN dog)))) (. .)))";
  Trees.PennTreeReader treeReader = new Trees.PennTreeReader(new StringReader(sample));
  Tree<String> sampleTree = treeReader.next();
  System.out.println("tree");
  System.out.println(Trees.PennTreeRenderer.render(sampleTree));

  Binarization[] schemes = Binarization.values();
  for (int i = 0; i < schemes.length; i++) {
    Binarization scheme = schemes[i];
    System.out.println("binarization type "+scheme.name());
    try {
      Tree<String> binarized = binarizeTree(sampleTree, scheme);
      System.out.println(Trees.PennTreeRenderer.render(binarized));
      System.out.println("unbinarized");
      Tree<String> restored = unAnnotateTree(binarized);
      System.out.println(Trees.PennTreeRenderer.render(restored));
      System.out.println("------------");
    } catch (Error e) {
      System.out.println("binarization not implemented");
    }
  }
}
/**
 * Collapses repeated labels inside unary chains.
 * <p>
 * Preterminals and leaves are returned unchanged. A node that does not have
 * exactly one child is processed recursively and modified in place. For a
 * node with exactly one child, the unary chain below it is followed down to
 * the first node that is a preterminal or does not have exactly one child;
 * the chain is then rebuilt so that each distinct label appears at most once.
 * <p>
 * NOTE(review): the intermediate labels are collected in a HashSet, so the
 * order in which they are re-inserted follows the set's iteration order, not
 * the original order of the chain - confirm this is acceptable to callers.
 *
 * @param tree the tree to clean up; may be modified in place
 * @return the cleaned-up tree (not necessarily the same object as the input)
 */
public static Tree<String> removeSuperfluousNodes(Tree<String> tree){
if (tree.isPreTerminal()) return tree;
if (tree.isLeaf()) return tree;
List<Tree<String>> gChildren = tree.getChildren();
if (gChildren.size()!=1) {
// nothing to do, just recurse
ArrayList<Tree<String>> children = new ArrayList<Tree<String>>();
for (int i=0; i<gChildren.size(); i++){
Tree<String> cChild = removeSuperfluousNodes(tree.getChildren().get(i));
children.add(cChild);
}
// the input node itself is updated with the processed children
tree.setChildren(children);
return tree;
}
// From here on the node has exactly one child, i.e. it starts a unary chain.
Tree<String> result = null;
String parent = tree.getLabel();
// distinct labels of the intermediate chain nodes (duplicates collapse here)
HashSet<String> nodesInChain = new HashSet<String>();
tree = tree.getChildren().get(0);
// walk down the chain until a preterminal or a node without exactly one child
while (!tree.isPreTerminal() && tree.getChildren().size()==1){
if (!nodesInChain.contains(tree.getLabel())){
nodesInChain.add(tree.getLabel());
}
tree = tree.getChildren().get(0);
}
// tree is now the bottom of the chain; clean up whatever lies below it
Tree<String> child = removeSuperfluousNodes(tree);
String cLabel = child.getLabel();
ArrayList<Tree<String>> childs = new ArrayList<Tree<String>>();
childs.add(child);
if (cLabel.equals(parent)) {
// top and bottom of the chain share a label: keep only the bottom node
result = child;
} else {
result = new Tree<String>(parent,childs);
}
// re-insert each remaining distinct label as a node directly below result,
// pushing result's current children one level down
for (String node : nodesInChain){
if (node.equals(parent)||node.equals(cLabel)) continue;
Tree<String> intermediate = new Tree<String>(node, result.getChildren());
childs = new ArrayList<Tree<String>>();
childs.add(intermediate);
result.setChildren(childs);
}
return result;
}
/**
 * Prints every unary chain of the form parent -> node -> child found in the
 * tree to standard output, one line per chain. Preterminals are never
 * reported as the middle element of a chain.
 *
 * @param tree the subtree to inspect
 * @param parent label of the unary parent of tree, or "" when tree was not
 *        reached through a unary node
 */
public static void displayUnaryChains(Tree<String> tree, String parent) {
  List<Tree<String>> kids = tree.getChildren();
  if (kids.size() != 1) {
    // branching (or childless) node: any chain below starts afresh
    for (Tree<String> kid : kids) {
      if (!kid.isPreTerminal()) {
        displayUnaryChains(kid, "");
      }
    }
    return;
  }
  boolean hasUnaryParent = !parent.equals("");
  if (tree.isPreTerminal()) {
    return;
  }
  Tree<String> onlyKid = kids.get(0);
  if (hasUnaryParent) {
    System.out.println("Unary chain: "+parent+" -> "+ tree.getLabel() + " -> " +onlyKid.getLabel());
  }
  displayUnaryChains(onlyKid, tree.getLabel());
}
/**
 * Removes the intermediate nodes of unary chains in place. Whenever a node's
 * only child is itself a unary node that is not a preterminal, that child is
 * spliced out and the node adopts the grandchild directly. Preterminals (and
 * hence the words below them) are never removed.
 * <p>
 * The splice is repeated until the node's only child is no longer a
 * removable unary node, so a chain of any length is fully collapsed. (A
 * single splice per node would leave every other intermediate node of a
 * longer chain in place, e.g. A->B->C->D->E would become A->C->E.)
 *
 * @param tree the tree to modify in place
 */
public static void removeUnaryChains(Tree<String> tree) {
  if (tree.isPreTerminal()) {
    return;
  }
  while (tree.getChildren().size() == 1) {
    Tree<String> onlyChild = tree.getChildren().get(0);
    // stop just above a preterminal or at a branching/leaf child
    if (onlyChild.isPreTerminal() || onlyChild.getChildren().size() != 1) {
      break;
    }
    // splice out the intermediate node
    ArrayList<Tree<String>> newChildren = new ArrayList<Tree<String>>();
    newChildren.add(onlyChild.getChildren().get(0));
    tree.setChildren(newChildren);
  }
  for (Tree<String> child : tree.getChildren()) {
    removeUnaryChains(child);
  }
}
}