//
// Ensemble - runs a linear-interpolation of shift-reduce parsers
// Copyright (c) 2009-2010 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Mihai Surdeanu
// mihais AT stanford DOT edu
//
package edu.stanford.nlp.parser.ensemble;
import edu.stanford.cs.ra.arguments.Argument;
import edu.stanford.cs.ra.arguments.Arguments;
import edu.stanford.nlp.parser.ensemble.utils.Eisner;
import edu.stanford.nlp.parser.ensemble.utils.ProjectivizeCorpus;
import edu.stanford.nlp.parser.ensemble.utils.ReverseCorpus;
import edu.stanford.nlp.parser.ensemble.utils.Scorer;
import edu.stanford.nlp.parser.ensemble.utils.Scorer.Score;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.commons.io.FileUtils;
import org.maltparser.core.helper.SystemLogger;
/**
 * Command-line driver for an ensemble of shift-reduce dependency parsers.
 *
 * <p>With {@code --run train} every base model listed in {@code --baseModelNames} gets a training
 * job for {@code --trainCorpus}. With {@code --run test} every base model gets a test job for
 * {@code --testCorpus`}. Their outputs are then combined with {@link Eisner#ensemble} (using
 * {@code --reparseAlgorithm}) and scored with {@link Scorer}.
 *
 * <p>Corpus preparation:
 * <ul>
 *   <li>When any right-to-left ({@code rtl}) base model is configured, a reversed copy of the
 *       corpus is written to the working directory.</li>
 *   <li>When training {@code +PP} base models, a pseudo-projectivized copy is written first.</li>
 * </ul>
 * Intermediate files are registered with {@link File#deleteOnExit()}.
 *
 * <p>Jobs run sequentially, or on a fixed-size thread pool when multi-threading is enabled for
 * the current mode. In the threaded case {@link #run()} blocks until {@link #threadFinished()}
 * has been called once per job.
 */
public class Ensemble {
  /** All base-model names accepted by {@code --baseModelNames}. */
  private static final Set<String> BASE_MODELS = Collections.unmodifiableSet(new HashSet<String>(Arrays.asList(
      "nivreeager-ltr", "nivrestandard-ltr", "covnonproj-ltr",
      "nivreeager-ltr+PP", "nivrestandard-ltr+PP",
      "nivreeager-rtl", "nivrestandard-rtl", "covnonproj-rtl",
      "nivreeager-rtl+PP", "nivrestandard-rtl+PP")));

  @Argument("Name of the ensemble model.")
  @Argument.Default("ensemble")
  @Argument.Switch("--modelName")
  String modelName;
  @Argument("Comma-separated list of base models to use in the ensemble.")
  @Argument.Default("nivreeager-ltr,nivrestandard-ltr,nivrestandard-rtl")
  @Argument.Switch("--baseModelNames")
  String baseModelNames;
  /**
   * Parsed list of models from baseModelNames
   */
  String[] baseModels;
  @Argument("Feature model specification (use comma to separate feature model of each algorithm).")
  @Argument.Default("<default>")
  @Argument.Switch("--featureModelNames")
  private String featureModelNames;
  /**
   * Parsed list of feature models from featureModelNames, aligned with baseModels; base models
   * without an explicit entry get "&lt;default&gt;".
   */
  String[] featureModels;
  @Argument("Location of the training corpus.")
  @Argument.Switch("--trainCorpus")
  String trainCorpus = null;
  @Argument("Location of the evaluation corpus.")
  @Argument.Switch("--testCorpus")
  String testCorpus = null;
  @Argument("Prefix output files generated during evaluation with this string (all output files are saved in workingDirectory).")
  @Argument.Switch("--outputPrefix")
  String outputPrefix = null;
  @Argument("True if during training the ensemble should create one thread per base model")
  @Argument.Default("false")
  @Argument.Switch("--multiThreadTrain")
  boolean multiThreadTrain;
  @Argument("True if during evaluation the ensemble should create one thread per base model")
  @Argument.Default("false")
  @Argument.Switch("--multiThreadEval")
  boolean multiThreadEval;
  @Argument("The model files will be saved in this directory.")
  @Argument.Default("/tmp")
  @Argument.Switch("--modelDirectory")
  String modelDirectory;
  @Argument("Temporary files created during execution will be stored (and deleted on completion) in this directory.")
  @Argument.Default("/tmp")
  @Argument.Switch("--workingDirectory")
  String workingDirectory;
  @Argument("Training options for liblinear.")
  @Argument.Default("-s_4_-e_0.1_-c_0.2_-B_1.0")
  @Argument.Switch("--libLinearOptions")
  String libLinearOptions;
  @Argument("Log level: off|fatal|error|warn|info|debug")
  @Argument.Default("info")
  @Argument.Switch("--logLevel")
  String logLevel;
  @Argument("Liblinear log level: silent|error|all")
  @Argument.Default("error")
  @Argument.Switch("--libLinearLogLevel")
  String libLinearLogLevel;
  @Argument("Use this external program to train liblinear (should be more robust)")
  @Argument.Default("")
  @Argument.Switch("--libLinearTrain")
  String libLinearTrain;
  @Argument("Split base models based on this column.")
  @Argument.Default("POSTAG")
  @Argument.Switch("--dataSplitColumn")
  String dataSplitColumn;
  //
  // Use this in combination with "-s Input[0]" for non covington models or with "-s Right[0]" for covington models
  //
  @Argument("Data split threshold for base models.")
  @Argument.Default("100")
  @Argument.Switch("--dataSplitThreshold")
  Integer dataSplitThreshold;
  @Argument("train|test")
  @Argument.Default("test")
  @Argument.Switch("--run")
  String run;
  @Argument("Reparsing algorithm: majority|attardi|eisner")
  @Argument.Default("eisner")
  @Argument.Switch("--reparseAlgorithm")
  String reparseAlgorithm;
  @Argument("Size of the thread pool, if multi-threaded processing is enabled.")
  @Argument.Default("4")
  @Argument.Switch("--threadCount")
  private Integer threadCount;
  /**
   * Automatically set to true if any base model requires right-to-left
   * processing
   */
  private boolean rightToLeft;
  /** Automatically set to true if any right-to-left base model requires pseudo-projective processing. */
  private boolean rtl_pseudo_projective;
  /** Automatically set to true if any left-to-right base model requires pseudo-projective processing. */
  private boolean ltr_pseudo_projective;
  /**
   * Number of {@link #threadFinished()} calls since the current thread pool was started.
   * Guarded by {@code this}. Counting completions (instead of relying on a single notify) means
   * a job that finishes before the main thread starts waiting is not lost.
   */
  private int finishedThreads;

  public static void main(String[] args) throws Exception {
    Ensemble ensemble = new Ensemble(args);
    ensemble.run();
  }

  /**
   * Parses command-line arguments into this instance and validates them.
   *
   * @param args command-line arguments, as understood by {@code Arguments.parse}
   * @throws RuntimeException if a directory or corpus is missing/inaccessible, a base model name
   *     is unknown, or the run mode is neither train nor test
   */
  public Ensemble(String[] args) {
    Arguments.parse(args, this);
    SystemLogger.instance().setSystemVerbosityLevel(logLevel);

    // sanity checks
    requireWritableDirectory(modelDirectory, "Model directory");
    requireWritableDirectory(workingDirectory, "Working directory");

    baseModels = baseModelNames.split(",");
    String[] models = featureModelNames.split(",");
    featureModels = new String[baseModels.length];
    for (int i = 0; i < baseModels.length; i++) {
      // base models beyond the end of --featureModelNames fall back to the default feature model
      featureModels[i] = i < models.length ? models[i] : "<default>";
    }

    rightToLeft = false;
    rtl_pseudo_projective = false;
    ltr_pseudo_projective = false;
    for (String bm : baseModels) {
      if (!BASE_MODELS.contains(bm)) {
        throw new RuntimeException("Unknown base model: " + bm);
      }
      boolean pseudoProjective = bm.endsWith("+PP");
      if (bm.contains("rtl")) {
        rightToLeft = true;
        if (pseudoProjective) {
          rtl_pseudo_projective = true;
        }
      } else if (pseudoProjective) {
        ltr_pseudo_projective = true;
      }
    }

    if (!run.equals(Const.RUN_TEST) && !run.equals(Const.RUN_TRAIN)) {
      throw new RuntimeException("Unknown run mode: " + run);
    }
    // From here on run is exactly one of the two constants, so plain equals() is sufficient.
    if (run.equals(Const.RUN_TRAIN)) {
      if (trainCorpus == null) {
        throw new RuntimeException("Training corpus must be specified if --run train!");
      }
      requireReadableFile(trainCorpus, "Training corpus");
      SystemLogger.logger().info("Will run in TRAIN mode.\n");
    }
    if (run.equals(Const.RUN_TEST)) {
      if (testCorpus == null) {
        throw new RuntimeException("Test corpus must be specified if --run test!");
      }
      requireReadableFile(testCorpus, "Test corpus");
      SystemLogger.logger().info("Will run in TEST mode.\n");
    }
  }

  /**
   * Fails unless {@code path} names an existing, writable directory.
   *
   * @param path directory to check
   * @param label capitalized description used in error messages, e.g. "Model directory"
   */
  private static void requireWritableDirectory(String path, String label) {
    File dir = new File(path);
    if (!dir.exists()) {
      throw new RuntimeException("ERROR: " + label + " " + dir.getAbsolutePath() + " does not exist!");
    }
    if (!dir.isDirectory()) {
      throw new RuntimeException("ERROR: " + label + " " + dir.getAbsolutePath() + " is not a directory!");
    }
    if (!dir.canWrite()) {
      throw new RuntimeException("ERROR: Must have write permission to "
          + label.toLowerCase(Locale.ROOT) + " " + dir.getAbsolutePath() + "!");
    }
  }

  /**
   * Fails unless {@code path} names an existing, readable regular file.
   *
   * @param path file to check
   * @param label capitalized description used in error messages, e.g. "Training corpus"
   */
  private static void requireReadableFile(String path, String label) {
    File f = new File(path);
    if (!f.exists()) {
      throw new RuntimeException("ERROR: " + label + " " + f.getAbsolutePath() + " does not exist!");
    }
    if (!f.isFile()) {
      throw new RuntimeException("ERROR: " + label + " " + f.getAbsolutePath() + " is not a file!");
    }
    if (!f.canRead()) {
      throw new RuntimeException("ERROR: Must have read permission to "
          + label.toLowerCase(Locale.ROOT) + " " + f.getAbsolutePath() + "!");
    }
  }

  /**
   * Prepares the corpus variants needed by the base models and runs one job per base model.
   * In test mode it then builds and scores the ensemble output.
   *
   * @throws IOException if copying, reversing, ensembling or scoring a corpus fails
   * @throws RuntimeException if the run mode is unknown or projectivization fails
   */
  public void run() throws IOException {
    List<Runnable> jobs = createJobs();
    boolean train = run.equals(Const.RUN_TRAIN);
    boolean test = run.equals(Const.RUN_TEST);
    if (!train && !test) {
      throw new RuntimeException("Unknown run mode: " + run);
    }
    boolean multiThreaded = (train && multiThreadTrain) || (test && multiThreadEval);
    String corpus = train ? trainCorpus : testCorpus;
    String phase = train ? "training" : "testing";

    if (rightToLeft) {
      prepareRightToLeftCorpus(corpus, phase, train);
    }
    if (ltr_pseudo_projective && train) {
      prepareLeftToRightCorpus(corpus, phase);
    }

    if (multiThreaded) {
      runInThreadPool(jobs);
    } else {
      for (Runnable job : jobs) {
        job.run();
      }
    }

    // run the actual ensemble model
    if (test) {
      buildAndScoreEnsemble();
    }
    SystemLogger.logger().info("DONE.\n");
  }

  /**
   * Copies {@code source} into the working directory (keeping its name) and schedules the copy
   * for deletion on exit. This is a no-op when {@code source} already is that file. In that case
   * {@code FileUtils.copyFile} would reject the copy, and the deletion hook would otherwise
   * remove the user's original corpus.
   */
  private void copyToWorkingDirectory(File source) throws IOException {
    File copy = new File(workingDirectory + File.separator + source.getName());
    if (copy.getCanonicalFile().equals(source.getCanonicalFile())) {
      return;
    }
    copy.deleteOnExit();
    FileUtils.copyFile(source, copy);
  }

  /**
   * Runs {@code ProjectivizeCorpus.Projectivize} on {@code inputName} (a file name inside the
   * working directory) and writes the file name of {@code outputPath}.
   *
   * <p>The output file and {@code <ppModelName>.mco} (resolved against the current directory)
   * are scheduled for deletion on exit.
   *
   * @throws RuntimeException wrapping the original failure if projectivization fails
   */
  private void projectivize(String inputName, String outputPath, String ppModelName) {
    File output = new File(outputPath);
    try {
      ProjectivizeCorpus.Projectivize(workingDirectory, inputName, output.getName(), ppModelName);
    } catch (Exception e) {
      throw new RuntimeException("Error: cannot projectivize corpus", e);
    }
    output.deleteOnExit();
    new File(ppModelName + ".mco").deleteOnExit();
  }

  /**
   * Writes the reversed corpus used by right-to-left base models to
   * {@code <workingDirectory>/<name>.reversed}. When training pseudo-projective rtl models, the
   * corpus is first projectivized to {@code <name>.pp}, and {@code <name>.pp.reversed} is written
   * instead.
   */
  private void prepareRightToLeftCorpus(String corpusPath, String phase, boolean train) throws IOException {
    File f = new File(corpusPath);
    copyToWorkingDirectory(f);
    if (rtl_pseudo_projective && train) {
      String ppReversedFileName = workingDirectory + File.separator + f.getName() + ".pp";
      SystemLogger.logger().debug("Projectivise reversing " + phase + " corpus to " + ppReversedFileName + "\n");
      projectivize(f.getName(), ppReversedFileName, "pp-reverse");
      f = new File(ppReversedFileName);
    }
    String reversedFileName = workingDirectory + File.separator + f.getName() + ".reversed";
    SystemLogger.logger().debug("Reversing " + phase + " corpus to " + reversedFileName + "\n");
    ReverseCorpus.reverseCorpus(f.getAbsolutePath(), reversedFileName);
    new File(reversedFileName).deleteOnExit();
  }

  /**
   * Writes the pseudo-projectivized corpus used by left-to-right {@code +PP} base models to
   * {@code <workingDirectory>/<name>.pp}.
   */
  private void prepareLeftToRightCorpus(String corpusPath, String phase) throws IOException {
    File f = new File(corpusPath);
    copyToWorkingDirectory(f);
    String ppFileName = workingDirectory + File.separator + f.getName() + ".pp";
    SystemLogger.logger().debug("Projectivise " + phase + " corpus to " + ppFileName + "\n");
    projectivize(f.getName(), ppFileName, "pp");
  }

  /**
   * Executes {@code jobs} on a pool of {@code threadCount} threads and blocks until each job has
   * reported completion through {@link #threadFinished()}.
   */
  private void runInThreadPool(List<Runnable> jobs) {
    synchronized (this) {
      finishedThreads = 0;
    }
    ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
    for (Runnable job : jobs) {
      threadPool.execute(job);
    }
    threadPool.shutdown();
    this.waitForThreads(jobs.size());
  }

  /**
   * Combines the per-base-model outputs into {@code <outputPrefix>.<modelName>-ensemble} in the
   * working directory. It then logs LAS/UAS against the test corpus when the scorer returns a
   * score.
   */
  private void buildAndScoreEnsemble() throws IOException {
    String prefix = workingDirectory + File.separator + outputPrefix + "." + modelName;
    String outFile = prefix + "-ensemble";
    List<String> sysFiles = new ArrayList<String>(baseModels.length);
    for (String baseModel : baseModels) {
      sysFiles.add(prefix + "-" + baseModel);
    }
    // generate the ensemble
    Eisner.ensemble(testCorpus, sysFiles, outFile, reparseAlgorithm);
    // score the ensemble
    Score s = Scorer.evaluate(testCorpus, outFile);
    if (s != null) {
      SystemLogger.logger().info(String.format("ensemble LAS: %.2f %d/%d\n", s.las, s.lcorrect, s.total));
      SystemLogger.logger().info(String.format("ensemble UAS: %.2f %d/%d\n", s.uas, s.ucorrect, s.total));
    }
    SystemLogger.logger().info("Ensemble output saved as: " + outFile + "\n");
  }

  /**
   * Blocks until {@link #threadFinished()} has been called {@code count} times since the pool
   * was started, logging each completion.
   *
   * <p>Waiting on the counter in a loop tolerates spurious wake-ups and completions that happen
   * before this method starts waiting. On interruption the interrupt flag is restored and the
   * method returns early.
   */
  private synchronized void waitForThreads(int count) {
    int reported = 0;
    while (reported < count) {
      if (finishedThreads > reported) {
        reported++;
        SystemLogger.logger().info("One thread finished. " + (count - reported) + " still going.\n");
        continue;
      }
      try {
        this.wait();
      } catch (InterruptedException e) {
        SystemLogger.logger().info("Main thread interrupted!\n");
        Thread.currentThread().interrupt();
        break;
      }
    }
    SystemLogger.logger().info("All threads finished.\n");
  }

  /**
   * Records that one job has finished and wakes the thread blocked in waitForThreads. Jobs run
   * in the thread pool must call this exactly once when done.
   */
  public synchronized void threadFinished() {
    finishedThreads++;
    this.notifyAll();
  }

  /** Creates one train or test job per base model, depending on the run mode. */
  private List<Runnable> createJobs() {
    List<Runnable> jobs = new ArrayList<Runnable>(baseModels.length);
    if (run.equals(Const.RUN_TRAIN)) {
      for (int i = 0; i < baseModels.length; i++) {
        jobs.add(new RunnableTrainJob(this, i));
      }
    } else if (run.equals(Const.RUN_TEST)) {
      for (int i = 0; i < baseModels.length; i++) {
        jobs.add(new RunnableTestJob(this, i));
      }
    }
    return jobs;
  }
}