/**
*
*/
package edu.berkeley.nlp.scripts;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.syntax.Trees.PennTreeReader;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.parser.EnglishPennTreebankParseEvaluator;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Filter;
import edu.berkeley.nlp.util.Numberer;
/**
* Takes a treebank with observed split categories and puts it into our format
* @author petrov
*
*/
public class GermanSharedTask {
Numberer tagNumberer;
List<Numberer> substateNumberers;
public Grammar extractGrammar(List<Tree<String>> trainTrees){
tagNumberer = Numberer.getGlobalNumberer("tags");
substateNumberers = new ArrayList<Numberer>();
short[] numSubStates = countSymbols(trainTrees);
List<Tree<String>> trainTreesNoGF = stripOffGF(trainTrees);
StateSetTreeList stateSetTrees = new StateSetTreeList(trainTreesNoGF, numSubStates, false, tagNumberer);
Grammar grammar = createGrammar(stateSetTrees, trainTrees, numSubStates);
return grammar;
}
private void checkGrammar(Grammar grammar, List<Tree<String>> trainTrees, List<Tree<String>> goldTrees) {
EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> eval = new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(new HashSet<String>(Arrays.asList(new String[] {"ROOT","PSEUDO"})), new HashSet<String>(Arrays.asList(new String[] {"''", "``", ".", ":", ","})));
List<Tree<String>> trainTreesNoGF = stripOffGF(trainTrees);
StateSetTreeList stateSetTrees = new StateSetTreeList(trainTreesNoGF, grammar.numSubStates, false, tagNumberer);
int index = 0;
for (Tree<StateSet> stateSetTree : stateSetTrees){
Tree<String> goldTree = goldTrees.get(index++);
while (goldTree.getYield().size()!=stateSetTree.getYield().size()&&index<=goldTrees.size()){
goldTree = goldTrees.get(index++);
}
List<String> goldPOS = goldTree.getPreTerminalYield();
Tree<String> labeledTree = guessGF(stateSetTree, grammar, goldPOS);
Tree<String> debinarizedTree = Trees.spliceNodes(labeledTree, new Filter<String>() {
public boolean accept(String s) {
return s.startsWith("@");
}
});
Tree<String> goldDebTree = Trees.spliceNodes(goldTree, new Filter<String>() {
public boolean accept(String s) {
return s.startsWith("@");
}
});
eval.evaluate(goldDebTree, debinarizedTree);
int t = 1;
t++;
}
eval.display(true);
}
private void labelTrees(Grammar grammar, List<Tree<String>> trainTrees, List<List<String>> goldPOStags) {
List<Tree<String>> trainTreesNoGF = stripOffGF(trainTrees);
StateSetTreeList stateSetTrees = new StateSetTreeList(trainTreesNoGF, grammar.numSubStates, false, tagNumberer);
int index = 0;
for (Tree<StateSet> stateSetTree : stateSetTrees){
List<String> goldPOS = goldPOStags.get(index++);
Tree<String> labeledTree = guessGF(stateSetTree, grammar, goldPOS);
Tree<String> debinarizedTree = Trees.spliceNodes(labeledTree, new Filter<String>() {
public boolean accept(String s) {
return s.startsWith("@");
}
});
System.out.println(debinarizedTree+"\n");
}
}
/**
* @param stateSetTree
* @param grammar
* @param goldPOS
* @return
*/
private Tree<String> guessGF(Tree<StateSet> stateSetTree, Grammar grammar, List<String> goldPOS) {
doInsideScores(stateSetTree, grammar, goldPOS);
return extractBestViterbiDerivation(grammar,stateSetTree,0);
}
private List<Tree<String>> stripOffGF(List<Tree<String>> trainTrees) {
List<Tree<String>> trainTreesNoGF = new ArrayList<Tree<String>>(trainTrees.size());
for (Tree<String> tree : trainTrees){
trainTreesNoGF.add(tree.shallowClone());
}
for (Tree<String> tree : trainTreesNoGF){
for (Tree<String> node : tree.getPostOrderTraversal()){
if (tree.isLeaf()) continue;
String label = node.getLabel();
int cutIndex = label.indexOf('-');
if (cutIndex!=-1) label = label.substring(0,cutIndex);
node.setLabel(label);
}
}
return trainTreesNoGF;
}
private Grammar createGrammar(StateSetTreeList stateSetTrees, List<Tree<String>> trainTrees, short[] numSubStates) {
Grammar grammar = new Grammar(numSubStates, false, new NoSmoothing(), null, -1);
int index = 0;
for (Tree<StateSet> stateSetTree : stateSetTrees){
Tree<String> tree = trainTrees.get(index++);
setScores(stateSetTree, tree);
grammar.tallyStateSetTree(stateSetTree, grammar);
}
grammar.optimize(0); // M Step
return grammar;
}
private void setScores(Tree<StateSet> stateSetTree, Tree<String> tree) {
if (tree.isLeaf()) return;
String[] labels = splitLabel(tree.getLabel());
StateSet stateSet = stateSetTree.getLabel();
int substate = substateNumberers.get(stateSet.getState()).number(labels[1]);
stateSet.setIScore(substate, 1.0);
stateSet.setIScale(0);
stateSet.setOScore(substate, 1.0);
stateSet.setOScale(0);
int nChildren = tree.getChildren().size();
if (nChildren != stateSetTree.getChildren().size()) System.err.println("Mismatch!");
for (int i=0; i<nChildren; i++){
setScores(stateSetTree.getChildren().get(i), tree.getChildren().get(i));
}
}
private short[] countSymbols(List<Tree<String>> trainTrees) {
for (Tree<String> tree : trainTrees){
processTree(tree);
}
short[] numSubStates = new short[tagNumberer.total()];
for (int substate=0; substate<numSubStates.length; substate++){
numSubStates[substate] = (short)substateNumberers.get(substate).total();
}
return numSubStates;
}
private void processTree(Tree<String> tree) {
String[] labels = splitLabel(tree.getLabel());
int state = tagNumberer.number(labels[0]);
if (state >= substateNumberers.size()) {
substateNumberers.add(new Numberer());
}
substateNumberers.get(state).number(labels[1]);
for (Tree<String> child : tree.getChildren()){
if (!child.isLeaf()) processTree(child);
}
}
/**
* @param label
* @return
*/
private String[] splitLabel(String label) {
String[] labels = label.split("-");
if (labels.length==1) labels = new String[]{labels[0],""};
return labels;
}
Tree<String> extractBestViterbiDerivation(Grammar grammar, Tree<StateSet> tree, int substate){
if (tree.isLeaf()) return new Tree<String>(tree.getLabel().getWord());
if (substate==-1) substate=0;
if (tree.isPreTerminal()){
ArrayList<Tree<String>> child = new ArrayList<Tree<String>>();
child.add(extractBestViterbiDerivation(grammar, tree.getChildren().get(0),-1));
int state = tree.getLabel().getState();
String goalStr = (String)tagNumberer.object(state);
String gfStr = (String)substateNumberers.get(state).object(substate);
if (!gfStr.equals("")) goalStr = goalStr + "-" + gfStr;
return new Tree<String>(goalStr, child);
}
StateSet node = tree.getLabel();
short pState = node.getState();
ArrayList<Tree<String>> newChildren = new ArrayList<Tree<String>>();
List<Tree<StateSet>> children = tree.getChildren();
double myScore = node.getIScore(substate);
if (myScore==Double.NEGATIVE_INFINITY){
myScore = DoubleArrays.max(node.getIScores());
substate = DoubleArrays.argMax(node.getIScores());
}
switch (children.size()) {
case 1:
StateSet child = children.get(0).getLabel();
short cState = child.getState();
int nChildStates = child.numSubStates();
double[][] uscores = grammar.getUnaryScore(pState,cState);
int childIndex = -1;
for (int j = 0; j < nChildStates; j++) {
if (childIndex != -1) break;
if (uscores[j]!=null) {
double cS = child.getIScore(j);
if (cS==0) continue;
double rS = uscores[j][substate]; // rule score
if (rS==0) continue;
double res = rS * cS;
if (matches(res,myScore)){
childIndex = j;
}
}
}
newChildren.add(extractBestViterbiDerivation(grammar, children.get(0), childIndex));
break;
case 2:
StateSet leftChild = children.get(0).getLabel();
StateSet rightChild = children.get(1).getLabel();
int nLeftChildStates = leftChild.numSubStates();
int nRightChildStates = rightChild.numSubStates();
short lState = leftChild.getState();
short rState = rightChild.getState();
double[][][] bscores = grammar.getBinaryScore(pState,lState,rState);
int lChildIndex = -1, rChildIndex = -1;
for (int j = 0; j < nLeftChildStates; j++) {
if (lChildIndex!=-1 && rChildIndex!=-1) break;
double lcS = leftChild.getIScore(j);
if (lcS==0) continue;
for (int k = 0; k < nRightChildStates; k++) {
if (lChildIndex!=-1 && rChildIndex!=-1) break;
double rcS = rightChild.getIScore(k);
if (rcS==0) continue;
if (bscores[j][k]!=null) { // check whether one of the parents can produce these kids
double rS = bscores[j][k][substate];
if (rS==0) continue;
double res = rS * lcS * rcS;
if (matches(myScore,res)){
lChildIndex = j;
rChildIndex = k;
}
}
}
}
newChildren.add(extractBestViterbiDerivation(grammar, children.get(0), lChildIndex));
newChildren.add(extractBestViterbiDerivation(grammar, children.get(1), rChildIndex));
break;
default:
throw new Error ("Malformed tree: more than two children");
}
int state = node.getState();
String parentString = (String)tagNumberer.object(state);
if (parentString.endsWith("^g")) parentString = parentString.substring(0,parentString.length()-2);
String gfStr = (String)substateNumberers.get(state).object(substate);
if (!gfStr.equals("")) parentString = parentString + "-" + gfStr;
return new Tree<String>(parentString, newChildren);
}
protected boolean matches(double x, double y) {
return (Math.abs(x - y) / (Math.abs(x) + Math.abs(y) + 1e-10) < 1.0e-4);
}
void doInsideScores(Tree<StateSet> tree, Grammar grammar, List<String> goldPOS) {
if (tree.isLeaf()){
return;
}
List<Tree<StateSet>> children = tree.getChildren();
for (Tree<StateSet> child : children) {
if (!child.isLeaf()) doInsideScores(child, grammar, goldPOS);
}
StateSet parent = tree.getLabel();
short pState = parent.getState();
int nParentStates = parent.numSubStates();
if (tree.isPreTerminal()) {
// Plays a role similar to initializeChart()
String POS = goldPOS.get(parent.from);
String[] labels = splitLabel(POS);
int substate = 0;
if (pState<grammar.numStates){
substate = substateNumberers.get(pState).number(labels[1]);
if (substate>=grammar.numSubStates[pState]){
System.err.println("Have never seen this POS: "+POS);
substate=0;
}
} else {
parent = new StateSet((short)(grammar.numStates-1), (short)1);
tree.setLabel(parent);
}
parent.setIScore(substate, 1.0);
parent.scaleIScores(0);
} else {
switch (children.size()) {
case 0:
break;
case 1:
StateSet child = children.get(0).getLabel();
short cState = child.getState();
int nChildStates = child.numSubStates();
double[][] uscores = grammar.getUnaryScore(pState,cState);
double[] iScores = new double[nParentStates];
boolean foundOne = false;
for (int j = 0; j < nChildStates; j++) {
if (uscores[j]!=null) { //check whether one of the parents can produce this child
double cS = child.getIScore(j);
if (cS==0) continue;
for (int i = 0; i < nParentStates; i++) {
double rS = uscores[j][i]; // rule score
if (rS==0) continue;
double res = rS * cS;
/*if (res == 0) {
System.out.println("Prevented an underflow: rS "+rS+" cS "+cS);
res = Double.MIN_VALUE;
}*/
iScores[i] += res;
foundOne = true;
}
}
}
parent.setIScores(iScores);
parent.scaleIScores(child.getIScale());
break;
case 2:
StateSet leftChild = children.get(0).getLabel();
StateSet rightChild = children.get(1).getLabel();
int nLeftChildStates = leftChild.numSubStates();
int nRightChildStates = rightChild.numSubStates();
short lState = leftChild.getState();
short rState = rightChild.getState();
double[][][] bscores = grammar.getBinaryScore(pState,lState,rState);
double[] iScores2 = new double[nParentStates];
boolean foundOne2 = false;
for (int j = 0; j < nLeftChildStates; j++) {
double lcS = leftChild.getIScore(j);
if (lcS==0) continue;
for (int k = 0; k < nRightChildStates; k++) {
double rcS = rightChild.getIScore(k);
if (rcS==0) continue;
if (bscores[j][k]!=null) { // check whether one of the parents can produce these kids
for (int i = 0; i < nParentStates; i++) {
double rS = bscores[j][k][i];
if (rS==0) continue;
double res = rS * lcS * rcS;
/*if (res == 0) {
System.out.println("Prevented an underflow: rS "+rS+" lcS "+lcS+" rcS "+rcS);
res = Double.MIN_VALUE;
}*/
iScores2[i] += res;
foundOne2 = true;
}
}
}
}
parent.setIScores(iScores2);
parent.scaleIScores(leftChild.getIScale()+rightChild.getIScale());
break;
default:
throw new Error("Malformed tree: more than two children");
}
}
}
private static List<Tree<String>> loadTrees(String inputFile) {
InputStreamReader inputData = null;
try {
inputData = new InputStreamReader(new FileInputStream(inputFile), "UTF-8");
} catch (UnsupportedEncodingException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
PennTreeReader treeReader = new PennTreeReader(inputData);
List<Tree<String>> trainTrees = new ArrayList<Tree<String>>();
Tree<String> tree = null;
while(treeReader.hasNext()){
tree = treeReader.next();
// trainTrees.add(TreeAnnotations.processTree(tree, 1, 0, Binarization.LEFT, false, false, false));
trainTrees.add(tree);
}
return trainTrees;
}
public static void main(String[] args) {
String inputFile = args[0];
List<Tree<String>> trainTrees = loadTrees(inputFile);
GermanSharedTask grEx = new GermanSharedTask();
Grammar grammar = grEx.extractGrammar(trainTrees);
inputFile = "/Users/petrov/Data/german_st/tueba/tueba_tmp";
List<Tree<String>> testTrees = loadTrees(inputFile);
inputFile = "/Users/petrov/Data/german_st/tueba/data02.mrg";
List<Tree<String>> goldTrees = loadTrees(inputFile);
List<List<String>> goldPOS = new ArrayList<List<String>>(goldTrees.size());
for (Tree<String> t : goldTrees){
goldPOS.add(t.getPreTerminalYield());
}
grEx.checkGrammar(grammar, testTrees, goldTrees);
// grEx.labelTrees(grammar, testTrees, goldPOS);
}
}