/* This file is part of the Joshua Machine Translation System.
 *
 * Joshua is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
 * MA 02111-1307 USA
 */
package joshua.decoder.ff.lm.distributed_lm;

import joshua.decoder.Support;
import joshua.util.SocketUtility;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Client side of the distributed language model when several LM servers
 * are used at once.
 *
 * <p>Each request is sent to every server by a dedicated, long-lived worker
 * thread (one per server); the per-server results are then combined as a
 * weighted sum. Results are memoized per request string in an instance-level
 * cache.
 *
 * <p>Performance consideration: we do not want to start new threads (or open
 * new sockets) for each n-gram request. Instead the worker threads and their
 * connections are created once in the constructor and wait for requests.
 *
 * <p>NOTE: all connection and thread state is held in {@code static} fields,
 * so constructing a second instance replaces the state used by the first,
 * and requests must not be issued concurrently from several threads.
 *
 * @author Zhifei Li, <zhifei.work@gmail.com>
 * @version $LastChangedDate: 2009-12-29 13:58:42 -0600 (Tue, 29 Dec 2009) $
 */
public class LMClientMultiServer extends LMClient {

    private static final Logger logger =
        Logger.getLogger(LMClientMultiServer.class.getName());

    /** One open connection per LM server. */
    public static SocketUtility.ClientConnection[] l_clients = null;
    /** Result of the current request, one slot per server. */
    public static double[] probs = null;
    /** Interpolation weight of each server. */
    public static double[] weights = null;
    /** One worker thread per LM server. */
    public static LMThread[] l_thread_handlers = null;
    public static int num_lm_servers = 1;
    /** The request currently being served; written by the requester, read by workers. */
    public static volatile String g_packet = null;
    public static long delayMillis = 5000; //5 seconds

    /** Maps a request string to its (interpolated) result. */
    HashMap<String,Double> request_cache = new HashMap<String,Double>();
    int cache_size_limit = 3000000;

    //thread communication
    /** Set by a worker once it has handled the current request; reset by the requester. */
    static boolean[] response_ready;
    /** Set by the requesting thread while a request is in flight; read by the workers. */
    static volatile boolean request_ready;
    /** Thread that constructed the client. Kept for compatibility; it is no longer interrupted. */
    static Thread p_main_thread;
    /** Set by {@link #close_client()} to make the workers terminate. */
    static volatile boolean should_finish = false;
    /** Polling interval (ms) of idle workers and of the requester's liveness check. */
    static long g_time_interval = 5000; //5 seconds

    /** Counts down once per worker for the request in flight. */
    private static volatile CountDownLatch response_latch = null;
    /** First failure reported by a worker for the request in flight, if any. */
    private static volatile RuntimeException worker_failure = null;

    //stat
    static int g_n_request = 0;
    static int g_n_cache_hit = 0;

    /**
     * Opens one connection and starts one worker thread per LM server.
     *
     * @param hostnames host name of each server
     * @param ports     port of each server
     * @param weights_  interpolation weight of each server
     * @param n_servers number of servers; the arrays must have at least this many entries
     */
    public LMClientMultiServer(
        String[] hostnames,
        int[]    ports,
        double[] weights_,
        int      n_servers
    ) {
        LMClientMultiServer.p_main_thread = Thread.currentThread();

        LMClientMultiServer.num_lm_servers    = n_servers;
        LMClientMultiServer.l_clients         = new SocketUtility.ClientConnection[n_servers];
        LMClientMultiServer.probs             = new double[n_servers];
        LMClientMultiServer.weights           = new double[n_servers];
        LMClientMultiServer.l_thread_handlers = new LMThread[n_servers];
        LMClientMultiServer.response_ready    = new boolean[n_servers];
        LMClientMultiServer.request_ready     = false;
        // A previous close_client() leaves this flag set; without the reset the
        // new workers would terminate right after their first wake-up.
        LMClientMultiServer.should_finish     = false;
        LMClientMultiServer.worker_failure    = null;
        LMClientMultiServer.response_latch    = null;

        for (int i = 0; i < n_servers; i++) {
            l_clients[i] = SocketUtility.open_connection_client(hostnames[i], ports[i]);
            LMClientMultiServer.weights[i] = weights_[i];

            //thread
            LMClientMultiServer.response_ready[i]    = false;
            LMClientMultiServer.l_thread_handlers[i] = new LMThread(i);
            LMClientMultiServer.l_thread_handlers[i].start();
        }
    }


    /**
     * Closes every server connection and asks every worker thread to stop.
     * Requests issued afterwards fail with a {@link RuntimeException}.
     */
    public void close_client() {
        should_finish = true;
        for (int i = 0; i < num_lm_servers; i++) {
            if (l_clients[i] != null) {
                l_clients[i].close();
            }
            // Wake the worker so that it notices should_finish immediately.
            l_thread_handlers[i].interrupt();
        }
    }


    /**
     * Convenience overload of {@link #get_prob(int[], int)}.
     * cmd: prob order wrd1 wrd2 ...
     */
    public double get_prob(ArrayList<Integer> ngram, int order) {
        return get_prob(Support.subIntArray(ngram, 0, ngram.size()), order);
    }


    /**
     * Returns the weighted sum of the values the servers report for the n-gram.
     * cmd: prob order wrd1 wrd2 ...
     *
     * @throws RuntimeException if a server request fails
     */
    public double get_prob(int[] ngram, int order) {
        String packet = encode_packet("prob", order, ngram);
        return exe_request(packet);
    }


    /** Not supported by this client; always throws. */
    public double get_prob_backoff_state(int[] ngram, int n_additional_bow) {
        throw new RuntimeException("call get_prob_backoff_state in lmclient, must exit");
    }


    /** Not supported by this client; always throws. */
    public int[] get_left_euqi_state(int[] original_state_wrds, int order, double[] cost) {
        throw new RuntimeException("call get_left_euqi_state in lmclient, must exit");
    }


    /** Not supported by this client; always throws. */
    public int[] get_right_euqi_state(int[] original_state, int order) {
        throw new RuntimeException("call get_right_euqi_state in lmclient, must exit");
    }


    /** Builds the request string {@code "<cmd> <num> <w1> <w2> ..."}. */
    private String encode_packet(String cmd, int num, int[] words) {
        StringBuilder packet = new StringBuilder();
        packet.append(cmd);
        packet.append(' ');
        packet.append(num);
        for (int i = 0; i < words.length; i++) {
            packet.append(' ');
            packet.append(words[i]);
        }
        return packet.toString();
    }


    /**
     * Answers a request from the cache, or forwards it to the servers and
     * caches the result. The cache is emptied once it grows beyond
     * {@code cache_size_limit} entries.
     */
    //TODO: synchronization problem to request_cache, if we use more than one LMClientMultiServer
    private double exe_request(String packet) {
        //search cache
        Double cmd_res = request_cache.get(packet);
        g_n_request++;

        //cache fail
        if (null == cmd_res) {
            //exe the request (a failure propagates before anything is cached)
            cmd_res = process_request_parallel(packet);

            //update cache
            if (request_cache.size() > cache_size_limit) {
                request_cache.clear();
            }
            request_cache.put(packet, cmd_res);
        } else {
            g_n_cache_hit++;
        }

        if (logger.isLoggable(Level.FINE) && g_n_request % 50000 == 0) {
            logger.fine(
                "n_requests: " + g_n_request
                + "; n_cache_hits: " + g_n_cache_hit
                + "; cache size= " + request_cache.size()
                + "; hit rate= " + g_n_cache_hit * 1.0 / g_n_request
            );
        }

        return cmd_res;
    }


    /**
     * Sends the request to all servers in parallel, waits for every worker to
     * report back, and returns the weighted sum of the per-server results.
     *
     * @throws RuntimeException if a worker reports a failure, or if a worker
     *         thread is no longer running and therefore can never answer
     */
    private double process_request_parallel(String packet) {
        // Publish the request before waking the workers: interrupt() (and the
        // volatile writes) make these values visible to the woken threads.
        CountDownLatch latch = new CountDownLatch(num_lm_servers);
        worker_failure = null;
        response_latch = latch;
        g_packet       = packet;
        request_ready  = true;

        //##### wake the workers
        for (int i = 0; i < num_lm_servers; i++) {
            probs[i]          = 0.0; //reset to zero
            response_ready[i] = false;
            l_thread_handlers[i].interrupt();
        }

        //##### wait until all are finished
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    if (latch.await(g_time_interval, TimeUnit.MILLISECONDS)) {
                        break;
                    }
                } catch (InterruptedException e) {
                    // Someone interrupted the caller; keep waiting for the
                    // servers and restore the flag before returning.
                    interrupted = true;
                    continue;
                }
                // Timed out: make sure nobody we are waiting for has died.
                ensure_pending_workers_alive();
            }
        } finally {
            request_ready = false;
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        RuntimeException failure = worker_failure;
        if (failure != null) {
            throw new RuntimeException("LM server request failed for packet: " + packet, failure);
        }

        //#### linear interpolate the results, all threads are done
        double sum = 0;
        for (int i = 0; i < num_lm_servers; i++) {
            sum += probs[i] * weights[i];
        }
        return sum;
    }


    /** Throws if a worker that has not answered the current request is no longer running. */
    private static void ensure_pending_workers_alive() {
        for (int i = 0; i < num_lm_servers; i++) {
            if (!l_thread_handlers[i].isAlive() && !response_ready[i]) {
                throw new RuntimeException(
                    "LM worker thread " + i + " is not running; request cannot complete");
            }
        }
    }


    /**
     * Worker thread bound to a single LM server. It sleeps until interrupted,
     * then serves the request in flight (if any) and goes back to sleep, until
     * {@code should_finish} is set.
     */
    private static class LMThread extends Thread {
        /** Index of the server, and of the slot the result is written to. */
        int pos;

        public LMThread(int p) {
            pos = p;
        }

        public void run() {
            while (!should_finish) {
                try {
                    Thread.sleep(g_time_interval); //sleep until interrupted
                } catch (InterruptedException e) {
                    // Woken up: either a request is ready or we should finish.
                    // The response_ready guard keeps an unrelated interrupt from
                    // serving (and counting down for) the same request twice.
                    if (request_ready && !response_ready[pos]) {
                        serve_current_request();
                    }
                }
            }
        }

        /**
         * Executes the request in flight against this worker's server and
         * reports the outcome. Failures are handed to the requesting thread
         * rather than killing this worker, so the requester never waits for
         * an answer that cannot arrive.
         */
        private void serve_current_request() {
            CountDownLatch latch = response_latch;
            try {
                String cmd_res = l_clients[pos].exe_request(g_packet);
                if (null == cmd_res) {
                    throw new RuntimeException("cmd_res is null, must exit");
                }
                probs[pos] = Double.parseDouble(cmd_res);
            } catch (RuntimeException failure) {
                worker_failure = failure;
            } finally {
                response_ready[pos] = true;
                if (latch != null) {
                    latch.countDown();
                }
            }
        }
    }
}