import java.math.*; import java.util.*; import java.io.*; import java.text.DecimalFormat; public class SimpleDecoder { private static DecimalFormat f3 = new DecimalFormat("###0.000"); public static void main(String[] args) throws Exception { String configFileName = args[0]; String sourceFileName = args[1]; String outputFileName = args[2]; int numParams = 3; int numSentences = countLines(sourceFileName); double[] weights = new double[numParams]; String candsFileName = ""; int cps = 0; int N = 0; InputStream inStream = new FileInputStream(new File(configFileName)); BufferedReader inFile = new BufferedReader(new InputStreamReader(inStream, "utf8")); String line = inFile.readLine(); while (line != null) { if (line.startsWith("cands_file")) { candsFileName = (line.substring(line.indexOf("=")+1)).trim(); } else if (line.startsWith("cands_per_sen")) { cps = Integer.parseInt((line.substring(line.indexOf("=")+1)).trim()); } else if (line.startsWith("top_n")) { N = Integer.parseInt((line.substring(line.indexOf("=")+1)).trim()); } else if (line.startsWith("LM")) { weights[0] = Double.parseDouble((line.substring(2+1)).trim()); } else if (line.startsWith("first model")) { weights[1] = Double.parseDouble((line.substring(11+1)).trim()); } else if (line.startsWith("second model")) { weights[2] = Double.parseDouble((line.substring(12+1)).trim()); } else if (line.startsWith("#")) { } else if (line.length() > 0) { println("Wrong format in config file."); System.exit(1); } line = inFile.readLine(); } inFile.close(); String[][] candidates = new String[numSentences][cps]; double[][][] features = new double[numSentences][cps][numParams]; inStream = new FileInputStream(new File(candsFileName)); inFile = new BufferedReader(new InputStreamReader(inStream, "utf8")); for (int i = 0; i < numSentences; ++i) { for (int n = 0; n < cps; ++n) { // read the nth candidate for the ith sentence line = inFile.readLine(); /* line format: i ||| words of candidate translation . ||| feat-1_val feat-2_val ... feat-numParams_val .* */ line = (line.substring(line.indexOf("|||")+3)).trim(); // get rid of initial text String candidate_str = (line.substring(0,line.indexOf("|||"))).trim(); String feats_str = (line.substring(line.indexOf("|||")+3)).trim(); // get rid of candidate string int junk_i = feats_str.indexOf("|||"); if (junk_i >= 0) { feats_str = (feats_str.substring(0,junk_i)).trim(); } String[] featVal_str = feats_str.split("\\s+"); candidates[i][n] = candidate_str; for (int c = 0; c < numParams; ++c) { features[i][n][c] = Double.parseDouble(featVal_str[c]); } } } double[][] scores = new double[numSentences][cps]; for (int i = 0; i < numSentences; ++i) { for (int n = 0; n < cps; ++n) { scores[i][n] = 0; for (int c = 0; c < numParams; ++c) { scores[i][n] += weights[c]*features[i][n][c]; } } } FileOutputStream outStream = new FileOutputStream(outputFileName, false); // false: don't append OutputStreamWriter outStreamWriter = new OutputStreamWriter(outStream, "utf8"); BufferedWriter outFile = new BufferedWriter(outStreamWriter); for (int i = 0; i < numSentences; ++i) { int[] indices = sort(scores,i); for (int n = 0; n < N; ++n) { String str = "" + i + " ||| " + candidates[i][indices[n]] + " |||"; for (int c = 0; c < numParams; ++c) { str += " " + f3.format(features[i][indices[n]][c]); } str += " ||| " + f3.format(scores[i][indices[n]]); writeLine(str, outFile); } } outFile.close(); System.exit(0); } private static int[] sort(double[][] scores, int i) { int numCands = scores[i].length; int[] retA = new int[numCands]; double[] sc = new double[numCands]; for (int n = 0; n < numCands; ++n) { retA[n] = n; sc[n] = scores[i][n]; } for (int j = 0; j < numCands; ++j) { int best_k = j; double best_sc = sc[j]; for (int k = j+1; k < numCands; ++k) { if (sc[k] > best_sc) { best_k = k; best_sc = sc[k]; } } // switch j and best_k int temp_n = retA[best_k]; retA[best_k] = retA[j]; retA[j] = temp_n; double temp_sc = sc[best_k]; sc[best_k] = sc[j]; sc[j] = temp_sc; } return retA; } private static void sort(int[] keys, double[] vals, int start, int end) { if (end-start > 1) { int mid = (start+end)/2; sort(keys,vals,start,mid); sort(keys,vals,mid+1,end); } } private static int countLines(String fileName) throws Exception { BufferedReader inFile = new BufferedReader(new FileReader(fileName)); String line; int count = 0; do { line = inFile.readLine(); if (line != null) ++count; } while (line != null); inFile.close(); return count; } private static void writeLine(String line, BufferedWriter writer) throws IOException { writer.write(line, 0, line.length()); writer.newLine(); writer.flush(); } private static void println(Object obj) { System.out.println(obj); } private static void print(Object obj) { System.out.print(obj); } }