/* This file is part of the Joshua Machine Translation System. * * Joshua is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 * of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, * MA 02111-1307 USA */ package joshua.zmert; import java.io.*; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Semaphore; public class TER extends EvaluationMetric { private boolean caseSensitive; private boolean withPunctuation; private int beamWidth; private int maxShiftDist; private String tercomJarFileName; private int numScoringThreads; public TER(String[] Metric_options) { // M_o[0]: case sensitivity, case/nocase // M_o[1]: with-punctuation, punc/nopunc // M_o[2]: beam width, positive integer // M_o[3]: maximum shift distance, positive integer // M_o[4]: filename of tercom jar file // M_o[5]: number of threads to use for TER scoring (= number of tercom processes launched) // for 0-3, default values in tercom-0.7.25 are: nocase, punc, 20, 50 if (Metric_options[0].equals("case")) { caseSensitive = true; } else if (Metric_options[0].equals("nocase")) { caseSensitive = false; } else { System.out.println("Unknown case sensitivity string " + Metric_options[0] + "."); System.out.println("Should be one of case or nocase."); System.exit(1); } if (Metric_options[1].equals("punc")) { withPunctuation = true; } else if (Metric_options[1].equals("nopunc")) { withPunctuation = false; } else { System.out.println("Unknown with-punctuation string " + Metric_options[1] + "."); System.out.println("Should be one of punc or nopunc."); System.exit(1); } beamWidth = Integer.parseInt(Metric_options[2]); if (beamWidth < 1) { System.out.println("Beam width must be positive"); System.exit(1); } maxShiftDist = Integer.parseInt(Metric_options[3]); if (maxShiftDist < 1) { System.out.println("Maximum shift distance must be positive"); System.exit(1); } tercomJarFileName = Metric_options[4]; if (tercomJarFileName == null || tercomJarFileName.equals("")) { System.out.println("Problem processing tercom's jar filename"); System.exit(1); } else { File checker = new File(tercomJarFileName); if (!checker.exists()) { System.out.println("Could not find tercom jar file " + tercomJarFileName); System.out.println("(Please make sure you use the full path in the filename)"); System.exit(1); } } numScoringThreads = Integer.parseInt(Metric_options[5]); if (numScoringThreads < 1) { System.out.println("Number of TER scoring threads must be positive"); System.exit(1); } TercomRunner.set_TercomParams(caseSensitive, withPunctuation, beamWidth, maxShiftDist, tercomJarFileName); initialize(); // set the data members of the metric } protected void initialize() { metricName = "TER"; toBeMinimized = true; suffStatsCount = 2; } public double bestPossibleScore() { return 0.0; } public double worstPossibleScore() { return (+1.0 / 0.0); } public int[] suffStats(String cand_str, int i) { // this method should never be used when the metric is TER, // because TER.java overrides createSuffStatsFile below, // which is the only method that calls suffStats(String,int). return null; } public int[][] suffStats(String[] cand_strings, int[] cand_indices) { // calculate sufficient statistics for each sentence in an arbitrary set of candidates int candCount = cand_strings.length; if (cand_indices.length != candCount) { System.out.println("Array lengths mismatch in suffStats(String[],int[]); returning null."); return null; } int[][] stats = new int[candCount][suffStatsCount]; try { // 1) Create input files for tercom // 1a) Create hypothesis file FileOutputStream outStream = new FileOutputStream("hyp.txt.TER", false); // false: don't append OutputStreamWriter outStreamWriter = new OutputStreamWriter(outStream, "utf8"); BufferedWriter outFile = new BufferedWriter(outStreamWriter); for (int d = 0; d < candCount; ++d) { writeLine(cand_strings[d] + " (ID" + d + ")",outFile); } outFile.close(); // 1b) Create reference file outStream = new FileOutputStream("ref.txt.TER", false); // false: don't append outStreamWriter = new OutputStreamWriter(outStream, "utf8"); outFile = new BufferedWriter(outStreamWriter); for (int d = 0; d < candCount; ++d) { for (int r = 0; r < refsPerSen; ++r) { writeLine(refSentences[cand_indices[d]][r] + " (ID" + d + ")",outFile); } } outFile.close(); // 2) Launch tercom as an external process runTercom("ref.txt.TER", "hyp.txt.TER", "TER_out", 500); // 3) Read SS from output file produced by tercom.7.25.jar BufferedReader inFile = new BufferedReader(new FileReader("TER_out.ter")); String line = ""; line = inFile.readLine(); // skip hyp line line = inFile.readLine(); // skip ref line for (int d = 0; d < candCount; ++d) { line = inFile.readLine(); // read info String[] strA = line.split("\\s+"); stats[d][0] = (int)Double.parseDouble(strA[1]); stats[d][1] = (int)Double.parseDouble(strA[2]); } // 4) Delete TER files File fd; fd = new File("hyp.txt.TER"); if (fd.exists()) fd.delete(); fd = new File("ref.txt.TER"); if (fd.exists()) fd.delete(); fd = new File("TER_out.ter"); if (fd.exists()) fd.delete(); } catch (IOException e) { System.err.println("IOException in TER.suffStats(String[],int[]): " + e.getMessage()); System.exit(99902); } return stats; } public void createSuffStatsFile(String cand_strings_fileName, String cand_indices_fileName, String outputFileName, int maxBatchSize) { try { int batchCount = 0; FileInputStream inStream_cands = new FileInputStream(cand_strings_fileName); BufferedReader inFile_cands = new BufferedReader(new InputStreamReader(inStream_cands, "utf8")); FileInputStream inStream_indices = new FileInputStream(cand_indices_fileName); BufferedReader inFile_indices = new BufferedReader(new InputStreamReader(inStream_indices, "utf8")); while (true) { ++batchCount; int readCount = createTercomHypFile(inFile_cands, tmpDirPrefix+"hyp.txt.TER.batch"+batchCount, 10000); createTercomRefFile(inFile_indices, tmpDirPrefix+"ref.txt.TER.batch"+batchCount, 10000); if (readCount == 0) { --batchCount; break; } else if (readCount < 10000) { break; } } // score the batchCount batches of candidates, in parallel, across numThreads threads ExecutorService pool = Executors.newFixedThreadPool(numScoringThreads); Semaphore blocker = new Semaphore(0); for (int b = 1; b <= batchCount; ++b) { pool.execute(new TercomRunner(blocker, tmpDirPrefix+"ref.txt.TER.batch"+b, tmpDirPrefix+"hyp.txt.TER.batch"+b, tmpDirPrefix+"TER_out.batch"+b, 500)); // Each thread scores the candidates, creating a tercom output file, // and then deletes the .hyp. and .ref. files, which are not needed // for other batches. } pool.shutdown(); try { blocker.acquire(batchCount); } catch(java.lang.InterruptedException e) { System.err.println("InterruptedException in TER.createSuffStatsFile(...): " + e.getMessage()); System.exit(99906); } PrintWriter outFile = new PrintWriter(outputFileName); for (int b = 1; b <= batchCount; ++b) { copySS(tmpDirPrefix+"TER_out.batch"+b+".ter", outFile); File fd; fd = new File(tmpDirPrefix+"TER_out.batch"+b+".ter"); if (fd.exists()) fd.delete(); // .hyp. and .ref. already deleted by individual threads } outFile.close(); } catch (IOException e) { System.err.println("IOException in TER.createSuffStatsFile(...): " + e.getMessage()); System.exit(99902); } } public int createTercomHypFile(BufferedReader inFile_cands, String hypFileName, int numCands) { // returns # lines read int readCount = 0; try { FileOutputStream outStream = new FileOutputStream(hypFileName, false); // false: don't append OutputStreamWriter outStreamWriter = new OutputStreamWriter(outStream, "utf8"); BufferedWriter outFile = new BufferedWriter(outStreamWriter); String line_cand = ""; if (numCands > 0) { for (int d = 0; d < numCands; ++d) { line_cand = inFile_cands.readLine(); if (line_cand != null) { ++readCount; writeLine(line_cand + " (ID" + d + ")",outFile); } else { break; } } } else { line_cand = inFile_cands.readLine(); int d = -1; while (line_cand != null) { ++readCount; ++d; writeLine(line_cand + " (ID" + d + ")",outFile); line_cand = inFile_cands.readLine(); } } outFile.close(); } catch (IOException e) { System.err.println("IOException in TER.createTercomHypFile(...): " + e.getMessage()); System.exit(99902); } return readCount; } public int createTercomRefFile(BufferedReader inFile_indices, String refFileName, int numIndices) { // returns # lines read int readCount = 0; try { FileOutputStream outStream = new FileOutputStream(refFileName, false); // false: don't append OutputStreamWriter outStreamWriter = new OutputStreamWriter(outStream, "utf8"); BufferedWriter outFile = new BufferedWriter(outStreamWriter); String line_index = ""; if (numIndices > 0) { for (int d = 0; d < numIndices; ++d) { line_index = inFile_indices.readLine(); if (line_index != null) { ++readCount; int index = Integer.parseInt(line_index); for (int r = 0; r < refsPerSen; ++r) { writeLine(refSentences[index][r] + " (ID" + d + ")",outFile); } } else { break; } } } else { line_index = inFile_indices.readLine(); int d = -1; while (line_index != null) { ++readCount; ++d; int index = Integer.parseInt(line_index); for (int r = 0; r < refsPerSen; ++r) { writeLine(refSentences[index][r] + " (ID" + d + ")",outFile); } line_index = inFile_indices.readLine(); } } outFile.close(); } catch (IOException e) { System.err.println("IOException in TER.createTercomRefFile(...): " + e.getMessage()); System.exit(99902); } return readCount; } public int runTercom(String refFileName, String hypFileName, String outFileNamePrefix, int memSize) { int exitValue = -1; try { String cmd_str = "java -Xmx" + memSize + "m -Dfile.encoding=utf8 -jar " + tercomJarFileName + " -r " + refFileName + " -h " + hypFileName + " -o ter -n " + outFileNamePrefix; cmd_str += " -b " + beamWidth; cmd_str += " -d " + maxShiftDist; if (caseSensitive) { cmd_str += " -s"; } if (!withPunctuation) { cmd_str += " -P"; } /* From tercom's README: -s case sensitivity, optional, default is insensitive -P no punctuations, default is with punctuations. */ Runtime rt = Runtime.getRuntime(); Process p = rt.exec(cmd_str); StreamGobbler errorGobbler = new StreamGobbler(p.getErrorStream(), 0); StreamGobbler outputGobbler = new StreamGobbler(p.getInputStream(), 0); errorGobbler.start(); outputGobbler.start(); exitValue = p.waitFor(); } catch (IOException e) { System.err.println("IOException in TER.runTercom(...): " + e.getMessage()); System.exit(99902); } catch (InterruptedException e) { System.err.println("InterruptedException in TER.runTercom(...): " + e.getMessage()); System.exit(99903); } return exitValue; } public void copySS(String inputFileName, PrintWriter outFile) { try { BufferedReader inFile = new BufferedReader(new FileReader(inputFileName)); String line = ""; line = inFile.readLine(); // skip hyp line line = inFile.readLine(); // skip ref line line = inFile.readLine(); // read info for first line while (line != null) { String[] strA = line.split("\\s+"); outFile.println((int)Double.parseDouble(strA[1]) + " " + (int)Double.parseDouble(strA[2])); line = inFile.readLine(); // read info for next line } } catch (IOException e) { System.err.println("IOException in TER.copySS(String,PrintWriter): " + e.getMessage()); System.exit(99902); } } public double score(int[] stats) { if (stats.length != suffStatsCount) { System.out.println("Mismatch between stats.length and suffStatsCount (" + stats.length + " vs. " + suffStatsCount + ") in TER.score(int[])"); System.exit(2); } double sc = 0.0; sc = stats[0]/(double)stats[1]; return sc; } public void printDetailedScore_fromStats(int[] stats, boolean oneLiner) { if (oneLiner) { System.out.println("TER = " + stats[0] + " / " + stats[1] + " = " + f4.format(score(stats))); } else { System.out.println("# edits = " + stats[0]); System.out.println("Reference length = " + stats[1]); System.out.println("TER = " + stats[0] + " / " + stats[1] + " = " + f4.format(score(stats))); } } private void writeLine(String line, BufferedWriter writer) throws IOException { writer.write(line, 0, line.length()); writer.newLine(); writer.flush(); } }