/* This file is part of the Joshua Machine Translation System. * * Joshua is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 * of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, * MA 02111-1307 USA */ package joshua.discriminative.monolingual_parser; import joshua.corpus.vocab.BuildinSymbol; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.JoshuaConfiguration; import joshua.decoder.ff.FeatureFunction; import joshua.decoder.ff.ArityPhrasePenaltyFF; import joshua.decoder.ff.PhraseModelFF; import joshua.decoder.ff.WordPenaltyFF; import joshua.decoder.ff.SourceLatticeArcLogPFF; import joshua.decoder.ff.tm.GrammarFactory; import joshua.util.FileUtility; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; import java.util.ArrayList; import java.util.logging.Level; import java.util.logging.Logger; /** * this class implements: * (1) mainly initialize, and control the interaction with JoshuaConfiguration and DecoderThread * * @author Zhifei Li, <zhifei.work@gmail.com> * @version $LastChangedDate: 2009-03-09 23:49:09 -0400 $ */ public class MonolingualJoshuaDecoder { private MonolingualDecoderFactory p_decoder_factory; // pointer to the main thread of decoding private GrammarFactory[] grammarFactories; private ArrayList<FeatureFunction> featureFunctions; private ArrayList<Integer> l_default_nonterminals; private SymbolTable p_symbolTable; boolean runEM =true; private static final Logger logger = Logger.getLogger(MonolingualJoshuaDecoder.class.getName()); //=============================================================== // Main //=============================================================== public static void main(String[] args) throws IOException { logger.finest("Starting decoder"); long start = System.currentTimeMillis(); if (args.length != 3 && args.length != 4) { System.out.println("Usage: java joshua.decoder.Decoder config_file test_file outfile (oracle_file)"); System.out.println("num of args is "+ args.length); for (int i = 0; i < args.length; i++) { System.out.println("arg is: " + args[i]); } System.exit(1); } String config_file = args[0].trim(); String test_file = args[1].trim(); String nbest_file = args[2].trim(); MonolingualJoshuaDecoder p_decoder = new MonolingualJoshuaDecoder(); //============== Step-1: initialize the decoder ============== p_decoder.initialize(config_file); //============== statistics double t_sec = (System.currentTimeMillis() - start) / 1000; if (logger.isLoggable(Level.INFO)) logger.info("before translation, loaddingtime is " + t_sec); //#============== Step-2: Decoding ============== p_decoder.decodingTestSet(test_file, nbest_file); //============== Step-3: clean up ============== p_decoder.cleanUp(); t_sec = (System.currentTimeMillis() - start) / 1000; if (logger.isLoggable(Level.INFO)) logger.info("Total running time is " + t_sec); } // end main() //=============================================================== //##### procedures: read config, init lm, init sym tbl, init models, read lm, read tm public void initialize(String config_file) { try { //=== read config file JoshuaConfiguration.readConfigFile(config_file); //=== initialize symbol table initSymbolTbl(); //TODO ##### add default non-terminals setDefaultNonTerminals(JoshuaConfiguration.default_non_terminal); //=== load TM grammar if (JoshuaConfiguration.tm_file != null) { initializeTranslationGrammars(JoshuaConfiguration.tm_file); } else { throw new RuntimeException("No translation grammargrammar was specified."); } //=== initialize the models(need to read config file again) featureFunctions = initializeFeatureFunctions(p_symbolTable, config_file); } catch (IOException e) { e.printStackTrace(); } } public void decodingTestSet(String test_file, String nbest_file) { // create factory if(runEM){//run EM training p_decoder_factory = new EMDecoderFactory( this.grammarFactories, false,//no LM model this.featureFunctions, this.l_default_nonterminals, this.p_symbolTable, nbest_file);//nbest_file *is* outGrammarFile }else{//regular decoding p_decoder_factory = new NbestDecoderFactory( this.grammarFactories, false,//no LM model this.featureFunctions, this.l_default_nonterminals, this.p_symbolTable, nbest_file); } // BUG: this works for Batch grammar only; not for sentence-specific grammars for (GrammarFactory grammarFactory : this.grammarFactories) { grammarFactory.getGrammarForSentence(null) .sortGrammar(this.featureFunctions); } p_decoder_factory.decodingTestSet(test_file); } public void cleanUp() { //TODO //p_lm_grammar.end_lm_grammar(); //end the threads } public static ArrayList<FeatureFunction> initializeFeatureFunctions(SymbolTable psymbolTable, String config_file) throws IOException { BufferedReader t_reader_config = FileUtility.getReadFileStream(config_file); ArrayList<FeatureFunction> l_models = new ArrayList<FeatureFunction>(); String line; while ((line = FileUtility.read_line_lzf(t_reader_config)) != null) { line = line.trim(); if (line.matches("^\\s*(?:\\#.*)?$")) { // ignore empty lines or lines commented out continue; } if (line.indexOf("=") == -1) { //ignore lines with "=" String[] fds = line.split("\\s+"); if (0 == fds[0].compareTo("latticecost") && fds.length == 2) { double weight = Double.parseDouble(fds[1].trim()); l_models.add(new SourceLatticeArcLogPFF(l_models.size(), weight)); if (logger.isLoggable(Level.FINEST)) logger.finest( String.format("Line: %s\nAdd Source lattice cost, weight: %.3f", weight)); } else if (0 == fds[0].compareTo("phrasemodel") && fds.length == 4) { // phrasemodel owner column(0-indexed) weight int owner = psymbolTable.addTerminal(fds[1]); int column = (new Integer(fds[2].trim())).intValue(); double weight = (new Double(fds[3].trim())).doubleValue(); l_models.add(new PhraseModelFF(l_models.size(), weight, owner, column)); if (logger.isLoggable(Level.FINEST)) logger.finest(String.format("Process Line: %s\nAdd PhraseModel, owner: %s; column: %d; weight: %.3f", line, owner, column, weight)); } else if (0 == fds[0].compareTo("arityphrasepenalty") && fds.length == 5){//arityphrasepenalty owner start_arity end_arity weight int owner = psymbolTable.addTerminal(fds[1]); int start_arity = (new Integer(fds[2].trim())).intValue(); int end_arity = (new Integer(fds[3].trim())).intValue(); double weight = (new Double(fds[4].trim())).doubleValue(); l_models.add(new ArityPhrasePenaltyFF(l_models.size(), weight, owner, start_arity, end_arity)); if (logger.isLoggable(Level.FINEST)) logger.finest(String.format("Process Line: %s\nAdd ArityPhrasePenalty, owner: %s; start_arity: %d; end_arity: %d; weight: %.3f",line, owner, start_arity, end_arity, weight)); } else if (0 == fds[0].compareTo("wordpenalty") && fds.length == 2) { // wordpenalty weight double weight = (new Double(fds[1].trim())).doubleValue(); l_models.add(new WordPenaltyFF(l_models.size(), weight)); if (logger.isLoggable(Level.FINEST)) logger.finest(String.format("Process Line: %s\nAdd WordPenalty, weight: %.3f", line, weight)); } else { if (logger.isLoggable(Level.SEVERE)) logger.severe("Wrong config line: " + line); System.exit(1); } } } t_reader_config.close(); return l_models; } private void setDefaultNonTerminals(String default_non_terminal) { //TODO ##### add default non-terminals l_default_nonterminals = new ArrayList<Integer>(); l_default_nonterminals.add(this.p_symbolTable.addNonterminal(default_non_terminal)); } private void initSymbolTbl() throws IOException { this.p_symbolTable = new BuildinSymbol(null); } // This depends (invisibly) on the language model in order to do pruning of the TM at load time. private void initializeTranslationGrammars(String tm_file) throws IOException { grammarFactories = new GrammarFactory[2]; // Glue Grammar GrammarFactory glueGrammar = //new MemoryBasedBatchGrammarWithPrune( new MonolingualGrammar( "monolingual", p_symbolTable, JoshuaConfiguration.glue_file, JoshuaConfiguration.phrase_owner, JoshuaConfiguration.default_non_terminal, JoshuaConfiguration.goal_symbol, -1, runEM); grammarFactories[0] = glueGrammar; // Regular TM Grammar GrammarFactory regularGrammar = //new MemoryBasedBatchGrammarWithPrune( new MonolingualGrammar( "monolingual", p_symbolTable, tm_file, JoshuaConfiguration.phrase_owner, JoshuaConfiguration.default_non_terminal, JoshuaConfiguration.goal_symbol, JoshuaConfiguration.span_limit, runEM); grammarFactories[1] = regularGrammar; //TODO if suffix-array: call SAGrammarFactory(SuffixArray sourceSuffixArray, CorpusArray targetCorpus, AlignmentArray alignments, LexicalProbabilities lexProbs, int maxPhraseSpan, int maxPhraseLength, int maxNonterminals, int spanLimit) { } public void writeConfigFile(double[] new_weights, String template, String file_to_write) { try { BufferedReader t_reader_config = FileUtility.getReadFileStream(template); BufferedWriter t_writer_config = FileUtility.getWriteFileStream(file_to_write); String line; int feat_id = 0; while ((line = FileUtility.read_line_lzf(t_reader_config)) != null) { line = line.trim(); if (line.matches("^\\s*(?:\\#.*)?$") || line.indexOf("=") != -1) { //comment, empty line, or parameter lines: just copy t_writer_config.write(line + "\n"); } else { //models: replace the weight String[] fds = line.split("\\s+"); StringBuffer new_line = new StringBuffer(); if (! fds[fds.length-1].matches("^[\\d\\.\\-\\+]+")) { System.out.println("last field is not a number, must be wrong; the field is: " + fds[fds.length-1]); System.exit(1); } for (int i = 0; i < fds.length-1; i++) { new_line.append(fds[i]); new_line.append(" "); } new_line.append(new_weights[feat_id++]); t_writer_config.write(new_line.toString() + "\n"); } } if (feat_id != new_weights.length) { System.out.println("number of models does not match number of weights, must be wrong"); System.exit(1); } t_reader_config.close(); t_writer_config.close(); } catch (IOException e) { e.printStackTrace(); } } }