/**
 *
 */
package edu.berkeley.nlp.discPCFG;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;

import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.GrammarMerger;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction.Counts;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.ScalingTools;

/**
 * Selects substate merges using a conditional-loss criterion that is tallied
 * in parallel over the training trees, applies the selected merges to the
 * grammar and lexicon, and saves the result.
 *
 * @author petrov
 */
public class ConditionalMerger {

    /** Number of worker tasks (and threads in {@link #pool}). */
    int nProcesses;
    /** Base name of the constraint files ({@code <base>-<block>.data}); may be null. */
    String consBaseName;
    Grammar grammar;
    Lexicon lexicon;
    /** Passed through to {@code GrammarMerger.determineMergePairs}. */
    double mergingPercentage;
    /** Base name of the output file; {@code "-merged"} is appended when saving. */
    String outFileName;
    /** The training trees, split into one list per worker. */
    StateSetTreeList[] trainingTrees;
    // NOTE(review): the pool is never shut down in this class; its worker
    // threads are non-daemon, so the caller presumably exits the JVM
    // explicitly - confirm.
    ExecutorService pool;
    Merger[] tasks;
    double[][] mergeWeights;

    /**
     * Worker that tallies the conditional loss over its share of the
     * training trees.
     */
    class Merger implements Callable<double[][][]> {
        ArrayParser gParser;
        ConstrainedTwoChartsParser eParser;
        StateSetTreeList myTrees;
        String consName;
        int myID;
        int nCounts;
        /** Per-tree constraints, loaded lazily on the first {@link #call()}. */
        boolean[][][][][] myConstraints;
        int unparsableTrees, incorrectLLTrees;
        double[][] mergeWeights;

        /**
         * @param myT          trees this worker is responsible for
         * @param consN        base name of the constraint files, or null for none
         * @param i            worker id, used to pick this worker's constraint blocks
         * @param gr           grammar handed to both parsers
         * @param lex          lexicon handed to both parsers
         * @param mergeWeights merge weights passed to the loss tally
         */
        Merger(StateSetTreeList myT, String consN, int i, Grammar gr, Lexicon lex, double[][] mergeWeights) {
            this.consName = consN;
            this.myTrees = myT;
            this.myID = i;
            this.mergeWeights = mergeWeights;
            gParser = new ArrayParser(gr, lex);
            eParser = new ConstrainedTwoChartsParser(gr, lex, null);
        }

        /**
         * Loads the constraints for every tree of this worker. Worker
         * {@code myID} reads blocks {@code myID, myID+nProcesses, ...}.
         * A tree whose constraint length disagrees with its sentence length
         * gets a null entry.
         *
         * @throws IllegalStateException if a constraint block cannot be loaded
         */
        private void loadConstraints() {
            // Fill a local array and publish it only on success, so that a
            // failed load is retried by the next call().
            boolean[][][][][] loaded = new boolean[myTrees.size()][][][][];
            if (consName == null) {
                myConstraints = loaded;
                return;
            }
            boolean[][][][][] curBlock = null;
            int block = 0;
            int i = 0;
            for (int tree = 0; tree < myTrees.size(); tree++) {
                if (curBlock == null || i >= curBlock.length) {
                    int blockNumber = (block * nProcesses) + myID;
                    String fileName = consName + "-" + blockNumber + ".data";
                    curBlock = loadData(fileName);
                    if (curBlock == null || curBlock.length == 0) {
                        throw new IllegalStateException("Could not load constraints from " + fileName);
                    }
                    block++;
                    i = 0;
                    System.out.print(".");
                }
                eParser.projectConstraints(curBlock[i], false);
                loaded[tree] = curBlock[i];
                i++;
                if (loaded[tree].length != myTrees.get(tree).getYield().size()) {
                    System.out.println("My ID: " + myID + ", block: " + block + ", sentence: " + i);
                    System.out.println("Sentence length and constraints length do not match!");
                    loaded[tree] = null;
                }
            }
            myConstraints = loaded;
        }

        /**
         * Runs both parsers over every tree of this worker and tallies the
         * conditional loss.
         *
         * @return the tallied deltas, indexed [state][substate][substate]
         */
        public double[][][] call() {
            if (myConstraints == null) {
                loadConstraints();
            }
            int nSubstates = mergeWeights[0].length;
            double[][][] deltas = new double[grammar.numStates][nSubstates][nSubstates];
            final boolean noSmoothing = true;
            final boolean debugOutput = false;
            int block = 0;
            int i = -1;
            for (Tree<StateSet> stateSetTree : myTrees) {
                i++;
                gParser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput);

                // parse the sentence
                List<StateSet> leaves = stateSetTree.getYield();
                int sentenceLength = leaves.size();

                boolean[][][][] cons = null;
                if (consName != null) {
                    cons = myConstraints[i];
                    // A null entry marks a tree whose constraints were dropped
                    // by loadConstraints(); it is parsed unconstrained.
                    if (cons != null && cons.length != sentenceLength) {
                        System.out.println("My ID: " + myID + ", block: " + block + ", sentence: " + i);
                        System.out.println("Sentence length (" + sentenceLength + ") and constraints length (" + cons.length + ") do not match!");
                        System.exit(-1);
                    }
                }
                eParser.doConstrainedInsideOutsideScores(leaves, cons, noSmoothing, stateSetTree, null, false);
                eParser.tallyConditionalLoss(stateSetTree, deltas, mergeWeights);
                if (i % 100 == 0) {
                    System.out.print(".");
                }
            }
            System.out.print(" " + myID + " ");
            return deltas;
        }

        /**
         * Reads a gzip-compressed, Java-serialized constraint block.
         *
         * @param fileName file to read
         * @return the deserialized block, or null if it could not be read
         */
        public boolean[][][][][] loadData(String fileName) {
            FileInputStream fis = null;
            ObjectInputStream in = null;
            try {
                fis = new FileInputStream(fileName); // Load from file
                in = new ObjectInputStream(new GZIPInputStream(fis)); // Compressed
                return (boolean[][][][][]) in.readObject();
            } catch (IOException e) {
                System.out.println("IOException\n" + e);
                return null;
            } catch (ClassNotFoundException e) {
                System.out.println("Class not found!");
                return null;
            } finally {
                // Close the outermost stream that was successfully opened.
                try {
                    if (in != null) {
                        in.close();
                    } else if (fis != null) {
                        fis.close();
                    }
                } catch (IOException ignored) {
                    // nothing useful can be done about a failed close
                }
            }
        }
    }

    /**
     * Prints the statistics of the generative merging criterion for
     * comparison, splits the training trees among the workers and creates
     * the worker tasks.
     *
     * @param processes         number of worker tasks / threads; must be at least 1
     * @param consBaseName      base name of the constraint files, or null for none
     * @param trainTrees        training trees
     * @param gr                grammar to merge
     * @param lex               lexicon to merge
     * @param mergingPercentage passed to {@code GrammarMerger.determineMergePairs}
     * @param outFileName       base name of the output file
     * @throws IllegalArgumentException if {@code processes} is less than 1
     */
    public ConditionalMerger(int processes, String consBaseName, StateSetTreeList trainTrees, Grammar gr,
            Lexicon lex, double mergingPercentage, String outFileName) {
        if (processes < 1) {
            throw new IllegalArgumentException("processes must be at least 1, got " + processes);
        }
        this.nProcesses = processes;
        this.consBaseName = consBaseName;
        this.grammar = gr;//.copyGrammar();
        this.lexicon = lex;//.copyLexicon();
        this.mergingPercentage = mergingPercentage;
        this.outFileName = outFileName;

        // Trees are dealt out in blocks; when constraint files exist, the
        // block size follows the size of the first constraint block.
        int nTreesPerBlock = trainTrees.size() / processes;
        if (consBaseName != null) {
            boolean[][][][][] tmp = edu.berkeley.nlp.PCFGLA.ParserConstrainer.loadData(consBaseName + "-0.data");
            if (tmp != null && tmp.length > 0) {
                nTreesPerBlock = tmp.length;
            }
        }
        // Guard the modulo below when there are fewer trees than processes.
        nTreesPerBlock = Math.max(1, nTreesPerBlock);

        // first compute the generative merging criterion
        mergeWeights = GrammarMerger.computeMergeWeights(grammar, lexicon, trainTrees);
        double[][][] deltas = GrammarMerger.computeDeltas(grammar, lexicon, mergeWeights, trainTrees);
        boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs(deltas, false, mergingPercentage, grammar);
        Grammar tmpGrammar = grammar.copyGrammar(true);
        Lexicon tmpLexicon = lexicon.copyLexicon();
        tmpGrammar = GrammarMerger.doTheMerges(tmpGrammar, tmpLexicon, mergeThesePairs, mergeWeights);
        System.out.println("Generative merging criterion gives:");
        GrammarMerger.printMergingStatistics(grammar, tmpGrammar);

        mergeWeights = GrammarMerger.computeMergeWeights(grammar, lexicon, trainTrees);

        // split the trees into chunks
        trainingTrees = new StateSetTreeList[nProcesses];
        for (int i = 0; i < nProcesses; i++) {
            trainingTrees[i] = new StateSetTreeList();
        }
        int block = -1;
        int inBlock = 0;
        for (int i = 0; i < trainTrees.size(); i++) {
            if (i % nTreesPerBlock == 0) {
                block++;
                System.out.println(inBlock);
                inBlock = 0;
            }
            trainingTrees[block % nProcesses].add(trainTrees.get(i));
            inBlock++;
        }

        pool = Executors.newFixedThreadPool(nProcesses);//CachedThreadPool();
        tasks = new Merger[nProcesses];
        for (int i = 0; i < nProcesses; i++) {
            tasks[i] = new Merger(trainingTrees[i], consBaseName, i, grammar, lexicon, mergeWeights);
        }
    }

    /**
     * Runs all worker tasks, sums their deltas, merges the selected substate
     * pairs and saves the merged grammar to {@code outFileName + "-merged"}.
     *
     * @throws IllegalStateException if a worker task fails or the calling
     *         thread is interrupted while waiting for it
     */
    public void mergeGrammarAndLexicon() {
        System.out.print("Task: ");
        List<Future<double[][][]>> submits = new ArrayList<Future<double[][][]>>(nProcesses);
        for (int i = 0; i < nProcesses; i++) {
            submits.add(pool.submit(tasks[i]));
        }

        // accumulate; Future.get() blocks until the task has finished, so no
        // polling is needed
        int nSubstates = mergeWeights[0].length;
        double[][][] deltas = new double[grammar.numStates][nSubstates][nSubstates];
        for (int i = 0; i < nProcesses; i++) {
            double[][][] counts;
            try {
                counts = submits.get(i).get();
            } catch (ExecutionException e) {
                throw new IllegalStateException("Merger task " + i + " failed", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for merger task " + i, e);
            }
            for (int a = 0; a < deltas.length; a++) {
                for (int b = 0; b < deltas[a].length; b++) {
                    for (int c = 0; c < deltas[a][b].length; c++) {
                        deltas[a][b][c] += counts[a][b][c];
                    }
                }
            }
        }
        System.out.print(" done. ");

        System.out.println("Conditional merging criterion gives:");
        boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs(deltas, false, mergingPercentage, grammar);
        Grammar newGrammar = GrammarMerger.doTheMerges(grammar, lexicon, mergeThesePairs, mergeWeights);
        GrammarMerger.printMergingStatistics(grammar, newGrammar);

        ParserData pData = new ParserData(lexicon, newGrammar, null, Numberer.getNumberers(),
                newGrammar.numSubStates, 1, 0, Binarization.RIGHT);

        // Report the file that is actually written.
        String mergedFileName = outFileName + "-merged";
        System.out.println("Saving grammar to " + mergedFileName + ".");
        if (pData.Save(mergedFileName)) {
            System.out.println("Saving successful.");
        } else {
            System.out.println("Saving failed!");
        }
    }
}