package edu.berkeley.nlp.PCFGLA;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import edu.berkeley.nlp.syntax.SpanTree;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
public class ParserConstrainer implements Callable {
StateSetTreeList stateSetTrees;
Grammar grammar;
Lexicon lexicon;
SpanPredictor spanPredictor;
String outBaseName;
double threshold;
String consName;
boolean keepGoldTreeAlive;
boolean useHierarchicalParser;
static int treesPerBlock;
int myID;
public ParserConstrainer(StateSetTreeList stateSetTrees, Grammar grammar,
Lexicon lexicon, SpanPredictor spanPredictor, String outBaseName,
double threshold, boolean keepGoldTreeAlive, int myID, String cons,
boolean useHierarchicalParser) {
this.stateSetTrees = stateSetTrees;
this.grammar = grammar;
this.lexicon = lexicon;
this.spanPredictor = spanPredictor;
this.outBaseName = outBaseName;
this.threshold = threshold;
this.consName = cons;
this.keepGoldTreeAlive = keepGoldTreeAlive;
this.myID = myID;
this.useHierarchicalParser = useHierarchicalParser;
}
public static void main(String[] args) {
OptionParser optParser = new OptionParser(ConditionalTrainer.Options.class);
ConditionalTrainer.Options opts = (ConditionalTrainer.Options) optParser
.parse(args, false);
// provide feedback on command-line arguments
System.out.println("Calling Constrainer with "
+ optParser.getPassedInOptions());
String path = opts.path;
// int lang = opts.lang;
System.out.println("Loading trees from " + path + " and using language "
+ opts.treebank);
String testSetString = opts.section;
boolean devTestSet = testSetString.equals("dev");
boolean finalTestSet = testSetString.equals("final");
boolean trainTestSet = testSetString.equals("train");
System.out.println(" using " + testSetString + " test set");
Corpus corpus = new Corpus(path, opts.treebank,
opts.trainingFractionToKeep, !trainTestSet);
List<Tree<String>> testTrees = null;
if (devTestSet)
testTrees = corpus.getDevTestingTrees();
if (finalTestSet)
testTrees = corpus.getFinalTestingTrees();
if (trainTestSet)
testTrees = corpus.getTrainTrees();
boolean manualAnnotation = false;
testTrees = Corpus.binarizeAndFilterTrees(testTrees,
opts.verticalMarkovization, opts.horizontalMarkovization, opts.maxL,
opts.binarization, manualAnnotation, GrammarTrainer.VERBOSE,
opts.markUnaryParents);
if (!devTestSet && opts.collapseUnaries)
System.out.println("Collpasing unary chains.");
testTrees = Corpus.filterTreesForConditional(testTrees,
opts.filterAllUnaries, opts.filterStupidFrickinWHNP, !devTestSet
&& opts.collapseUnaries);
boolean keepGoldAlive = opts.keepGoldTreeAlive || trainTestSet;
String inFileName = opts.inFile;
System.out.println("Loading grammar from " + inFileName + ".");
ParserData pData = ParserData.Load(inFileName);
if (pData == null) {
System.out
.println("Failed to load grammar from file " + inFileName + ".");
System.exit(1);
}
Grammar grammar = pData.getGrammar();
grammar.splitRules();
Lexicon lexicon = pData.getLexicon();
lexicon.explicitlyComputeScores(grammar.finalLevel);
SpanPredictor spanPredictor = pData.getSpanPredictor();
if (opts.flattenParameters != 1.0) {
System.out.println("Flattening parameters with exponent "
+ opts.flattenParameters + " to reduce overconfidence.");
grammar.removeUnlikelyRules(0, opts.flattenParameters);
lexicon.removeUnlikelyTags(0, opts.flattenParameters);
}
Numberer.setNumberers(pData.getNumbs());
Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
StateSetTreeList stateSetTrees = new StateSetTreeList(testTrees,
grammar.numSubStates, false, tagNumberer);
testTrees = null;
String outBaseName = opts.outFileName;
double threshold = Math.exp(opts.logT);
int nChunks = opts.nChunks;
int nTrees = stateSetTrees.size();
System.out.println("There are " + nTrees + " trees in this set.");
treesPerBlock = (int) Math.ceil(nTrees / (double) nChunks);
System.out.println("Will store " + treesPerBlock
+ " constraints per file, in " + nChunks + " files.");
System.out.println("All states with posterior probability below "
+ threshold + " will be pruned.");
if (keepGoldAlive)
System.out.println("But the gold tree will survive!");
System.out.println("The constraints will be written to " + outBaseName
+ ".");
// split the trees into chunks
StateSetTreeList[] trainingTrees = new StateSetTreeList[nChunks];
for (int i = 0; i < nChunks; i++) {
trainingTrees[i] = new StateSetTreeList();
}
int block = -1;
int inBlock = 0;
for (int i = 0; i < nTrees; i++) {
if (i % treesPerBlock == 0) {
block++;
// System.out.println(inBlock);
inBlock = 0;
}
trainingTrees[block].add(stateSetTrees.get(i));
inBlock++;
}
for (int i = 0; i < nChunks; i++) {
System.out.println("Process " + i + " has " + trainingTrees[i].size()
+ " trees.");
}
stateSetTrees = null;
ExecutorService pool = Executors.newFixedThreadPool(nChunks);
Future[] submits = new Future[nChunks];
ParserConstrainer thisThreadConstrainer = null;
if (nChunks == 1)
thisThreadConstrainer = new ParserConstrainer(trainingTrees[0], grammar,
lexicon, spanPredictor, outBaseName, threshold, keepGoldAlive, 0,
opts.cons, opts.hierarchicalChart);
else {
for (int i = 0; i < nChunks; i++) {
ParserConstrainer constrainer = new ParserConstrainer(trainingTrees[i],
grammar, lexicon, spanPredictor, outBaseName, threshold,
keepGoldAlive, i, opts.cons, opts.hierarchicalChart);
submits[i] = pool.submit(constrainer);
}
while (true) {
boolean done = true;
for (Future task : submits) {
done &= task.isDone();
}
if (done)
break;
}
// pool.shutdown();
}
try {
PrintWriter outputData = (opts.outputLog == null) ? new PrintWriter(
new OutputStreamWriter(System.out))
: new PrintWriter(new OutputStreamWriter(new FileOutputStream(
opts.outputLog), "UTF-8"), true);
for (int i = 0; i < nChunks; i++) {
StringBuilder sb = null;
if (nChunks == 1) {
sb = thisThreadConstrainer.call();
} else {
sb = (StringBuilder) submits[i].get();
}
outputData.print(sb.toString());
}
if (opts.outputLog != null) {
outputData.flush();
outputData.close();
}
} catch (ExecutionException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (InterruptedException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (UnsupportedEncodingException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
} catch (FileNotFoundException e1) {
// TODO Auto-generated catch block
e1.printStackTrace();
}
System.out.println("Done computing constraints.");
}
/**
*
* @param opts
*/
public StringBuilder call() {
ConstrainedTwoChartsParser parser = (grammar instanceof HierarchicalAdaptiveGrammar) ? new ConstrainedHierarchicalTwoChartParser(
grammar, lexicon, spanPredictor, grammar.finalLevel)
: new ConstrainedTwoChartsParser(grammar, lexicon, spanPredictor);
StringBuilder sb = new StringBuilder();
int recentHistoryIndex = 0;
// int sentenceNumber = 1;
boolean[][][][][] recentHistory = new boolean[treesPerBlock][][][][];
boolean[][][][][] myConstraints = null;
boolean useCons = consName != null;
if (useCons)
myConstraints = loadData(consName + "-" + myID + ".data");
boolean[][][][] cons = null;
for (Tree<StateSet> testTree : stateSetTrees) {
List<StateSet> yield = testTree.getYield();
List<String> testSentence = new ArrayList<String>(yield.size());
for (StateSet el : yield) {
testSentence.add(el.getWord());
}
sb.append("\n" + (myID * treesPerBlock + recentHistoryIndex + 1)
+ ". Length " + testSentence.size());
if (useCons) {
parser.projectConstraints(myConstraints[recentHistoryIndex], false);
cons = myConstraints[recentHistoryIndex];
}
Tree<StateSet> sTree = null;
if (keepGoldTreeAlive) {
// System.out.println("keeping gold tree alive");
sTree = testTree;
}
boolean[][][][] possibleStates = parser.getPossibleStates(testSentence,
sTree, threshold, cons, sb);
assert sTree == null || contains(possibleStates, sTree);
if (useCons)
myConstraints[recentHistoryIndex] = null;
recentHistory[recentHistoryIndex++] = possibleStates;
if (recentHistoryIndex % 1000 == 0)
System.out.print(".");
// sentenceNumber++;
// if (recentHistoryIndex>0 && (recentHistoryIndex % treesPerBlock == 0))
// {
// String fileName = outBaseName+"-"+blockIndex+".data";
// saveData(recentHistory, fileName);
// blockIndex++;
// if (useCons && sentenceNumber<nTrees) myConstraints =
// loadData(consName+"-"+blockIndex+".data");
// recentHistory = new boolean[treesPerBlock][][][][];
// recentHistoryIndex = 0;
// }
}
// if (recentHistoryIndex!=0) {
String fileName = outBaseName + "-" + myID + ".data";
saveData(recentHistory, fileName);
// }
return sb;
}
/**
* @param possibleStates
* @param tree
* @return
*/
private boolean contains(boolean[][][][] possibleStates, Tree<StateSet> tree) {
boolean[] bs = possibleStates[tree.getLabel().from][tree.getLabel().to][tree
.getLabel().getState()];
if (tree.isLeaf())
return true;
if (bs == null) {
assert false;
return false;
}
boolean hasTrue = false;
for (boolean b : bs)
hasTrue |= b;
if (!hasTrue) {
assert false;
return false;
}
boolean allThere = true;
for (Tree<StateSet> child : tree.getChildren()) {
allThere &= contains(possibleStates, child);
}
return allThere;
}
public static boolean saveData(boolean[][][][][] data, String fileName) {
try {
// here's some code from online; it looks good and gzips the output!
// there's a whole explanation at
// http://www.ecst.csuchico.edu/~amk/foo/advjava/notes/serial.html
// Create the necessary output streams to save the scribble.
FileOutputStream fos = new FileOutputStream(fileName); // Save to file
GZIPOutputStream gzos = new GZIPOutputStream(fos); // Compressed
ObjectOutputStream out = new ObjectOutputStream(gzos); // Save objects
out.writeObject(data); // Write the mix of grammars
out.flush(); // Always flush the output.
out.close(); // And close the stream.
gzos.close();
fos.close();
} catch (IOException e) {
System.out.println("IOException: " + e);
return false;
}
return true;
}
public static boolean isGoldReachable(SpanTree<String> gold,
List[][] possibleStates, Numberer tagNumberer) {
boolean reachable = true;
reachable = possibleStates[gold.getStart()][gold.getEnd()]
.contains(tagNumberer.number(gold.getLabel()));
if (reachable && (!gold.isLeaf())) {
for (SpanTree<String> child : gold.getChildren()) {
reachable = isGoldReachable(child, possibleStates, tagNumberer);
if (!reachable)
return false;
}
}
if (!reachable) {
System.out.println("Cannot reach state " + gold.getLabel()
+ " spanning from " + gold.getStart() + " to " + gold.getEnd() + ".");
}
return reachable;
}
public static SpanTree<String> convertToSpanTree(Tree<String> tree) {
if (tree.isPreTerminal()) {
return new SpanTree<String>(tree.getLabel());
}
if (tree.getChildren().size() > 2)
System.out.println("Binarize properly first!");
SpanTree<String> spanTree = new SpanTree<String>(tree.getLabel());
List<SpanTree<String>> spanChildren = new ArrayList<SpanTree<String>>();
for (Tree<String> child : tree.getChildren()) {
SpanTree<String> spanChild = convertToSpanTree(child);
spanChildren.add(spanChild);
}
spanTree.setChildren(spanChildren);
return spanTree;
}
public static boolean[][][][][] loadData(String fileName) {
boolean[][][][][] data = null;
try {
FileInputStream fis = new FileInputStream(fileName); // Load from file
GZIPInputStream gzis = new GZIPInputStream(fis); // Compressed
ObjectInputStream in = new ObjectInputStream(gzis); // Load objects
data = (boolean[][][][][]) in.readObject(); // Read the mix of grammars
in.close(); // And close the stream.
gzis.close();
fis.close();
} catch (IOException e) {
System.out.println("IOException\n" + e);
return null;
} catch (ClassNotFoundException e) {
System.out.println("Class not found!");
return null;
}
return data;
}
}