package joshua.discriminative.monolingual_parser; import java.io.IOException; import java.util.ArrayList; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.ff.FeatureFunction; import joshua.decoder.ff.tm.Grammar; import joshua.decoder.ff.tm.GrammarFactory; import joshua.decoder.ff.tm.Rule; import joshua.decoder.ff.tm.RuleCollection; import joshua.decoder.ff.tm.Trie; /** This class should have a way to store the accumulated count, normalize the count, * and assign values to the paramemters * */ public class EMDecoderFactory extends MonolingualDecoderFactory { private float normConstant = 0;//real semiring; the norm constant for LHS private double dataLogProb=0; private String outGrammarFile; public EMDecoderFactory(GrammarFactory[] grammar_facories, boolean have_lm_model_, ArrayList<FeatureFunction> l_feat_functions, ArrayList<Integer> l_default_nonterminals_, SymbolTable symbolTable, String outGrammarFile_) { super(grammar_facories, have_lm_model_, l_feat_functions, l_default_nonterminals_, symbolTable); outGrammarFile = outGrammarFile_; } @Override public MonolingualDecoderThread constructThread(int decoderID, String cur_test_file, int start_sent_id) throws IOException { MonolingualDecoderThread pdecoder = new EMDecoderThread( this, this.p_grammar_factories, this.have_lm_model, this.p_l_feat_functions, this.l_default_nonterminals, this.symbolTable, cur_test_file, start_sent_id ); return pdecoder; } @Override public void mergeParallelDecodingResults() throws IOException { //do nothing } @Override public void postProcess() throws IOException { //== run M step here reEstmateGrammars(); System.out.println("======== Data log prob is " + dataLogProb); } public void accumulateDataLogProb(double exampleLogProb){ //data likilihood is a product of each example, so it is a sum in log domain dataLogProb += exampleLogProb; } public void incrementRulePosteriorProb(Rule rl, double prob){ MonolingualGrammar.incrementRulePosteriorProb(rl, prob); } private void reEstmateGrammars(){ for (GrammarFactory grammarFactory : this.p_grammar_factories) { Grammar bathGrammar = grammarFactory.getGrammarForSentence(null); accumulatePosteriorCountInGrammar(bathGrammar); normalizePosteriorCountInGrammar(bathGrammar); //TODO: this will *correctly* write the regular grammar, instead of the GLUE grammar, but we should avoid this ((MonolingualGrammar) bathGrammar).writeGrammarOnDisk(outGrammarFile, this.symbolTable); bathGrammar.sortGrammar(this.p_l_feat_functions); } } private void accumulatePosteriorCountInGrammar(Grammar grammar) { normConstant =0; accumulatePosteriorCountInTrie(grammar.getTrieRoot()); System.out.println("normConstant for a grammar is " + normConstant); } private void accumulatePosteriorCountInTrie(Trie trie) { if(trie.hasRules()){ RuleCollection rlCollection = trie.getRules(); for(Rule rl : rlCollection.getSortedRules()){ //TODO: LHS specific normConstant += MonolingualGrammar.getRulePosteriorProb(rl); } } if (trie.hasExtensions()) { Object[] tem = trie.getExtensions().toArray(); for (int i = 0; i < tem.length; i++) { accumulatePosteriorCountInTrie((Trie)tem[i]); } } } private void normalizePosteriorCountInGrammar(Grammar grammar) { normalizePosteriorCountInTrie(grammar, grammar.getTrieRoot()); } private void normalizePosteriorCountInTrie(Grammar grammar, Trie trie) { RuleCollection rlCollection = trie.getRules(); if(trie.hasRules()){ for(Rule rl : rlCollection.getSortedRules()){ //TODO: LHS specific double oldVal = MonolingualGrammar.getRuleNormalizedCost(rl); //==add-lambda smoothing float smoothingConstant = 0.1f; float prob = (MonolingualGrammar.getRulePosteriorProb(rl)+smoothingConstant)/(normConstant+smoothingConstant*grammar.getNumRules()); MonolingualGrammar.setRuleNormalizedCost(rl, prob); double newVal = MonolingualGrammar.getRuleNormalizedCost(rl); if( symbolTable.getWord(rl.getLHS()).compareTo("S")==0 ){ System.out.println("count: "+ MonolingualGrammar.getRulePosteriorProb(rl) + "; norm: " + normConstant + "; old: " + oldVal + "; new: " + newVal); System.out.println(rl.toString(symbolTable)); } MonolingualGrammar.resetRulePosteriorProb(rl);//reset to zero } } if (trie.hasExtensions()) { Object[] tem = trie.getExtensions().toArray(); for (int i = 0; i < tem.length; i++) { normalizePosteriorCountInTrie(grammar, (Trie)tem[i]); } } } }