/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.discriminative.monolingual_parser;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.tm.GrammarFactory;
import joshua.util.FileUtility;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Abstract driver for decoding a test file with one or more
 * {@link MonolingualDecoderThread} workers.
 *
 * <ol>
 * <li>Parallel decoding: the test file is split into per-decoder input files,
 *     one worker is constructed per file, all workers are started and joined,
 *     and the results are merged by the subclass.</li>
 * <li>Non-parallel decoding: a single worker is constructed for the whole test
 *     file and its {@code decodeFile()} is invoked directly in the calling
 *     thread.</li>
 * </ol>
 *
 * Subclasses supply the worker construction, the merge step and the
 * post-processing step.
 *
 * @author Zhifei Li, <zhifei.work@gmail.com>
 * @version $LastChangedDate: 2008-10-20 00:12:30 -0400 $
 */
public abstract class MonolingualDecoderFactory {
    /** Grammar factories passed in at construction time; available to subclasses. */
    protected GrammarFactory[] p_grammar_factories = null;
    /** Flag passed in at construction time; available to subclasses. */
    protected boolean have_lm_model = false;
    /** Feature functions passed in at construction time; available to subclasses. */
    protected ArrayList<FeatureFunction> p_l_feat_functions = null;
    /** Default non-terminal ids passed in at construction time; available to subclasses. */
    protected ArrayList<Integer> l_default_nonterminals = null;
    /** Symbol table that receives every token of the test file before workers start. */
    protected SymbolTable symbolTable = null;
    /**
     * Workers created by the most recent parallel run. After a parallel run the
     * array contains exactly the workers that were constructed (no null slots).
     */
    protected MonolingualDecoderThread[] parallel_threads;

    private static final Logger logger = Logger.getLogger(MonolingualDecoderFactory.class.getName());

    /**
     * Stores the supplied resources in the corresponding protected fields.
     *
     * @param grammar_facories        stored in {@link #p_grammar_factories}
     * @param have_lm_model_          stored in {@link #have_lm_model}
     * @param l_feat_functions        stored in {@link #p_l_feat_functions}
     * @param l_default_nonterminals_ stored in {@link #l_default_nonterminals}
     * @param symbolTable             stored in {@link #symbolTable}
     */
    public MonolingualDecoderFactory(GrammarFactory[] grammar_facories, boolean have_lm_model_,
            ArrayList<FeatureFunction> l_feat_functions,
            ArrayList<Integer> l_default_nonterminals_, SymbolTable symbolTable) {
        this.p_grammar_factories = grammar_facories;
        this.have_lm_model = have_lm_model_;
        this.p_l_feat_functions = l_feat_functions;
        this.l_default_nonterminals = l_default_nonterminals_;
        this.symbolTable = symbolTable;
    }

    /**
     * Builds the worker that decodes one input file.
     *
     * @param decoderID     1-based index of the worker
     * @param cur_test_file path of the input file this worker decodes
     * @param start_sent_id number of sentences of the original test file that
     *                      precede the first line of {@code cur_test_file}
     * @return the worker; it is not started by this method's caller until all
     *         workers have been constructed
     * @throws IOException if the subclass fails to set the worker up
     */
    public abstract MonolingualDecoderThread constructThread(int decoderID, String cur_test_file, int start_sent_id) throws IOException;

    /**
     * Called once after all parallel workers have been joined.
     *
     * @throws IOException if merging fails
     */
    public abstract void mergeParallelDecodingResults() throws IOException;

    /**
     * Called once after decoding (single-threaded or parallel) has finished.
     *
     * @throws IOException if post-processing fails
     */
    public abstract void postProcess() throws IOException;

    /**
     * Decodes {@code test_file}, in parallel when
     * {@code JoshuaConfiguration.num_parallel_decoders} is greater than one and
     * in the calling thread otherwise, then runs {@link #postProcess()}.
     * <p>
     * An {@link IOException} raised anywhere in the process is logged at
     * {@code SEVERE} level and not rethrown. The JVM is terminated with exit
     * status 1 when parallel decoding is requested together with
     * {@code JoshuaConfiguration.use_remote_lm_server}.
     *
     * @param test_file path of the file to decode
     */
    public void decodingTestSet(String test_file) {
        try {
            //==== decode the sentences, maybe in parallel
            // A non-positive decoder count cannot be split into jobs, so it is
            // handled like the single-decoder case instead of crashing.
            if (JoshuaConfiguration.num_parallel_decoders <= 1) {
                MonolingualDecoderThread pdecoder = constructThread(1, test_file, 0);
                pdecoder.decodeFile(); // do not run *start*; so that we stay in the current main thread
            } else {
                if (JoshuaConfiguration.use_remote_lm_server) { // TODO
                    if (logger.isLoggable(Level.SEVERE)) {
                        logger.severe("You cannot run parallel decoder and remote lm server together");
                    }
                    System.exit(1);
                }
                runParallelDecoder(test_file);
            }
            postProcess();
        } catch (IOException e) {
            logger.log(Level.SEVERE, "I/O error while decoding test file " + test_file, e);
        }
    }

    /**
     * Splits {@code test_file} into per-decoder input files, runs one worker
     * per file, waits for all of them and merges their results.
     *
     * @param test_file path of the file to decode
     * @throws IOException if reading, splitting or merging fails
     */
    private void runParallelDecoder(String test_file) throws IOException {
        int num_decoders = JoshuaConfiguration.num_parallel_decoders;
        parallel_threads = new MonolingualDecoderThread[num_decoders];

        //==== compute number of lines for each decoder
        int n_lines = countLines(test_file);
        double num_per_thread_double = n_lines * 1.0 / num_decoders;
        // Clamp to at least one line per job: with fewer lines than decoders the
        // truncated quotient is 0 and would be used as a modulus divisor.
        int num_per_thread_int = Math.max(1, (int) num_per_thread_double);
        if (logger.isLoggable(Level.INFO)) {
            logger.info("num_per_file_double: " + num_per_thread_double
                    + ", num_per_file_int: " + num_per_thread_int);
        }

        //==== initialize all threads and their input files
        int num_jobs = splitTestFile(test_file, num_per_thread_int);

        // Fewer jobs than requested decoders are created when the test file is
        // short; drop the unused slots so nobody dereferences a null worker.
        if (num_jobs < parallel_threads.length) {
            MonolingualDecoderThread[] created = new MonolingualDecoderThread[num_jobs];
            System.arraycopy(parallel_threads, 0, created, 0, num_jobs);
            parallel_threads = created;
        }

        //==== run all the jobs
        for (int i = 0; i < parallel_threads.length; i++) {
            if (logger.isLoggable(Level.INFO)) {
                logger.info("=======start thread " + i);
            }
            parallel_threads[i].start();
        }

        //==== wait for the threads finish
        // An interrupt stops the wait for the current worker only; the remaining
        // workers are still joined, and the interrupt status is restored at the end.
        boolean interrupted = false;
        for (int i = 0; i < parallel_threads.length; i++) {
            try {
                parallel_threads[i].join();
            } catch (InterruptedException e) {
                interrupted = true;
                if (logger.isLoggable(Level.WARNING)) {
                    logger.warning("thread is interupted for server " + i);
                }
            }
        }

        //==== merge the nbest files, and remove tmp files
        try {
            mergeParallelDecodingResults();
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Counts the lines of {@code file} as delivered by
     * {@code FileUtility.read_line_lzf}.
     *
     * @param file path of the file to read
     * @return number of lines read before the reader reported end of input
     * @throws IOException if the file cannot be opened or read
     */
    private static int countLines(String file) throws IOException {
        BufferedReader reader = FileUtility.getReadFileStream(file);
        try {
            int n_lines = 0;
            while (FileUtility.read_line_lzf(reader) != null) {
                n_lines++;
            }
            return n_lines;
        } finally {
            closeQuietly(reader);
        }
    }

    /**
     * Copies {@code test_file} into consecutive files named
     * {@code JoshuaConfiguration.parallel_files_prefix + ".test." + k}
     * (k starting at 1), constructs one worker per file and stores it in
     * {@link #parallel_threads} at index {@code k - 1}. A new file is begun
     * after every {@code num_per_thread} lines until the last available slot is
     * reached; all remaining lines go into that last file. Every line is also
     * tokenized on whitespace and its tokens are added to {@link #symbolTable}.
     *
     * @param test_file      path of the file to split
     * @param num_per_thread number of lines per job; must be positive
     * @return number of workers constructed (at least 1); the last one may have
     *         an empty input file when the test file has fewer lines than slots
     * @throws IOException if reading, writing or worker construction fails
     */
    private int splitTestFile(String test_file, int num_per_thread) throws IOException {
        int decoder_i = 1;
        int sent_id = 0;
        int start_sent_id = 0;
        String cur_test_file = JoshuaConfiguration.parallel_files_prefix + ".test." + decoder_i;

        BufferedReader t_reader_test = FileUtility.getReadFileStream(test_file);
        BufferedWriter t_writer_test = null;
        try {
            t_writer_test = FileUtility.getWriteFileStream(cur_test_file);
            String cn_sent;
            while ((cn_sent = FileUtility.read_line_lzf(t_reader_test)) != null) {
                sent_id++;
                t_writer_test.write(cn_sent + "\n");

                // make the Symbol table finalized before running multiple threads,
                // this is to avoid synchronization among threads
                String[] words = cn_sent.split("\\s+");
                this.symbolTable.addTerminals(words); // TODO

                // we will include all additional lines into last file
                if (decoder_i < parallel_threads.length && sent_id % num_per_thread == 0) {
                    // prepare current job
                    t_writer_test.close();
                    t_writer_test = null;
                    parallel_threads[decoder_i - 1] = constructThread(decoder_i, cur_test_file, start_sent_id);

                    // prepare next job
                    start_sent_id = sent_id;
                    decoder_i++;
                    cur_test_file = JoshuaConfiguration.parallel_files_prefix + ".test." + decoder_i;
                    t_writer_test = FileUtility.getWriteFileStream(cur_test_file);
                }
            }

            //==== prepare the last job
            t_writer_test.close();
            t_writer_test = null;
            parallel_threads[decoder_i - 1] = constructThread(decoder_i, cur_test_file, start_sent_id);
        } finally {
            // Only reached with an open writer when something above failed; close
            // without throwing so the original exception is not masked.
            closeQuietly(t_writer_test);
            closeQuietly(t_reader_test);
        }
        return decoder_i;
    }

    /**
     * Closes {@code resource} if it is non-null, ignoring any
     * {@link IOException} from {@code close()}.
     *
     * @param resource the resource to close; may be {@code null}
     */
    private static void closeQuietly(java.io.Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (IOException ignored) {
            // Cleanup only: a close failure here must not hide the primary outcome.
        }
    }
}