package edu.stanford.nlp.semparse.open;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import edu.stanford.nlp.semparse.open.core.AllOptions;
import edu.stanford.nlp.semparse.open.core.InteractiveDemo;
import edu.stanford.nlp.semparse.open.core.OpenSemanticParser;
import edu.stanford.nlp.semparse.open.core.ParallelizedTrainer;
import edu.stanford.nlp.semparse.open.core.eval.Evaluator;
import edu.stanford.nlp.semparse.open.core.eval.EvaluatorStatistics;
import edu.stanford.nlp.semparse.open.core.eval.IterativeTester;
import edu.stanford.nlp.semparse.open.dataset.Dataset;
import edu.stanford.nlp.semparse.open.dataset.library.DatasetLibrary;
import edu.stanford.nlp.semparse.open.util.Parallelizer;
import fig.basic.LogInfo;
import fig.basic.Option;
import fig.exec.Execution;
/**
* Main class for training and testing models.
*
* <h1>Development Mode</h1>
 * <p><blockquote><code>
 * java edu.stanford.nlp.semparse.open.Main -dataset FAMILY.NAME [-trainFrac 0.8] [-testFrac 0.2] [-saveModel FILENAME] [-folds 1]
 * </code></blockquote></p>
 * <p>Specify 1 dataset to be divided into train + test based on the specified ratio (trainFrac &amp; testFrac).
* <p>To use the whole dataset as the training data, use options trainFrac = 1 and testFrac = 0.
* <p>The model can be saved by specifying the saveModel option.
 * <p>To run on multiple random splits in parallel, use the 'folds' option. (Note: saveModel cannot be used when folds &gt; 1)
*
 * <h1>Train &amp; Test Mode</h1>
 * <p><blockquote><code>
 * java edu.stanford.nlp.semparse.open.Main -dataset FAMILY.TRAIN_NAME@TEST_NAME [-saveModel FILENAME]
 * </code></blockquote></p>
* <p>Specify 2 datasets from the same family (one for train, one for test) separated by the "@" sign.
* <p>The model can be saved by specifying the saveModel option.
*
* <h1>Test Only Mode</h1>
 * <p><blockquote><code>
 * java edu.stanford.nlp.semparse.open.Main -loadModel FILENAME -dataset FAMILY.NAME
 * </code></blockquote></p>
* <p>Load model from file + Test on the specified dataset.
*
* <h1>Interactive Mode</h1>
 * <p><blockquote><code>
 * java edu.stanford.nlp.semparse.open.Main -loadModel FILENAME
 * </code></blockquote></p>
* <p>Load model from file + Start interactive mode where the desired query and web page can be entered directly.
*
* <h1>Experiment Mode</h1>
 * <p><blockquote><code>
 * java edu.stanford.nlp.semparse.open.experiment.Experiments -experiment NAME
 * </code></blockquote></p>
* See {@link edu.stanford.nlp.semparse.open.experiment.Experiments Experiments}.
*
*/
public class Main implements Runnable {
  public static class Options {
    @Option(gloss="dataset name (format = family.name)") public String dataset = null;
    @Option(gloss="filename for saving a trained model") public String saveModel = null;
    @Option(gloss="filename for loading a trained model") public String loadModel = null;
    @Option(gloss="number of folds") public int folds = 1;
  }
  public static Options opts = new Options();

  public static void main(String[] args) {
    Execution.run(args, new Main(), AllOptions.getOptionsParser());
  }

  // ============================================================
  // Overall run script
  // ============================================================

  /**
   * Dispatch based on the parsed options: a loadModel filename selects
   * load-and-test (or interactive) mode; otherwise train + test on a dataset,
   * possibly over several folds.
   */
  public void run() {
    // A single saved/loaded model file is ambiguous when training multiple folds.
    if (opts.folds > 1 && (opts.saveModel != null || opts.loadModel != null))
      LogInfo.fail("Cannot save or load a model with folds > 1");
    if (opts.loadModel != null) {
      loadAndTestMode();
    } else {
      trainAndTestMode();
    }
  }

  /** Load a saved model and test it on the dataset named by the options (may resolve to null). */
  public void loadAndTestMode() {
    loadAndTestMode(DatasetLibrary.getDataset(opts.dataset));
  }

  /**
   * Load a saved model; with no dataset, start the interactive demo, otherwise
   * evaluate the loaded model on the entire dataset treated as a test set.
   *
   * @param dataset dataset to test on, or null for interactive mode
   */
  public void loadAndTestMode(Dataset dataset) {
    OpenSemanticParser parser = AllOptions.loadModel(opts.loadModel);
    if (dataset == null) {
      // Interactive demo mode: queries and web pages are entered directly.
      new InteractiveDemo(parser).run();
    } else {
      // Test on the specified dataset; preTrain prepares dataset-level state first.
      new OpenSemanticParser().preTrain(dataset);
      testCombined(parser, dataset);
    }
    OpenSemanticParser.cleanUp();
  }

  /** Train and test on the dataset named by the options. */
  public void trainAndTestMode() {
    trainAndTestMode(DatasetLibrary.getDataset(opts.dataset));
  }

  /**
   * Train + test on the given dataset (possibly over several parallel folds),
   * then log per-fold and averaged summary statistics.
   *
   * @param dataset dataset providing train and test examples; must not be null
   */
  public void trainAndTestMode(Dataset dataset) {
    if (dataset == null)
      LogInfo.fail("Must specify either a dataset to train on or a model to load.");
    Execution.putOutput("numTrainExamples", dataset.trainExamples.size());
    Execution.putOutput("numTestExamples", dataset.testExamples.size());
    OpenSemanticParser.init();
    new OpenSemanticParser().preTrain(dataset);
    dataset.cacheRewards();
    List<IterativeTester> iterativeTesters =
        (opts.folds > 1) ? runParallel(dataset) : runSingle(dataset);
    OpenSemanticParser.cleanUp();
    summarize(iterativeTesters);
  }

  // ============================================================
  // Parallelization
  // ============================================================

  /** Train and test a single model on the given dataset (folds == 1 case). */
  private List<IterativeTester> runSingle(Dataset dataset) {
    List<IterativeTester> iterativeTesters = new ArrayList<>();
    iterativeTesters.add(trainAndTest(dataset).getIterativeTester());
    return iterativeTesters;
  }

  /**
   * Train opts.folds models in parallel, one per random split, then test the
   * first fold's model on the original (un-shuffled) split.
   */
  private List<IterativeTester> runParallel(Dataset dataset) {
    // Build one training task per fold: fold 0 keeps the original split,
    // every subsequent fold gets a fresh shuffle.
    List<ParallelizedTrainer> tasks = new ArrayList<>();
    Dataset shuffled = dataset;
    for (int i = 0; i < opts.folds; i++) {
      tasks.add(new ParallelizedTrainer(shuffled, i != 0));
      shuffled = shuffled.getNewShuffledDataset();
    }
    // Silence logging while training in parallel (interleaved logs are unreadable);
    // restore the old verbosity even if a training task throws.
    int oldLogVerbosity = OpenSemanticParser.opts.logVerbosity;
    OpenSemanticParser.opts.logVerbosity = 0;
    List<Future<OpenSemanticParser>> parsers;
    try {
      parsers = Parallelizer.runAndReturnStuff(tasks);
    } finally {
      OpenSemanticParser.opts.logVerbosity = oldLogVerbosity;
    }
    // Collect each fold's tester, then test fold 0's parser on the original split.
    List<IterativeTester> iterativeTesters = new ArrayList<>();
    try {
      for (int i = 0; i < opts.folds; i++)
        iterativeTesters.add(parsers.get(i).get().getIterativeTester());
      test(parsers.get(0).get(), dataset);
    } catch (ExecutionException | InterruptedException e) {
      LogInfo.fail(e);
    }
    return iterativeTesters;
  }

  // ============================================================
  // Train / Test
  // ============================================================

  /** Train a new parser on the dataset, saving the model if saveModel was given. */
  private OpenSemanticParser train(Dataset dataset) {
    OpenSemanticParser parser = new OpenSemanticParser();
    parser.train(dataset);
    if (opts.saveModel != null)
      AllOptions.saveModel(opts.saveModel, parser);
    return parser;
  }

  /**
   * Test on both training set and test set, logging error analysis and
   * summary scores for each.
   *
   * @return the same parser, for chaining
   */
  private OpenSemanticParser test(OpenSemanticParser parser, Dataset dataset) {
    // Fixed log-label typo (was "TRANING SET").
    Evaluator trainEvaluator = parser.test(dataset.trainExamples, "TRAINING SET");
    Evaluator testEvaluator = parser.test(dataset.testExamples, "TEST SET");
    LogInfo.begin_track("### Error Analysis ###");
    trainEvaluator.printDetails();
    testEvaluator.printDetails();
    LogInfo.end_track();
    LogInfo.begin_track("### Summary ###");
    trainEvaluator.printScores();
    testEvaluator.printScores();
    trainEvaluator.putOutput("train");
    testEvaluator.putOutput("test");
    LogInfo.end_track();
    return parser;
  }

  /**
   * Combine everything into the "test" dataset and test on it. Used when
   * evaluating a pre-trained, loaded model.
   *
   * @return the same parser, for chaining
   */
  private OpenSemanticParser testCombined(OpenSemanticParser parser, Dataset dataset) {
    dataset = new Dataset().addTestFromDataset(dataset);
    Evaluator testEvaluator = parser.test(dataset.testExamples, "TEST SET");
    LogInfo.begin_track("### Error Analysis ###");
    testEvaluator.printDetails();
    LogInfo.end_track();
    LogInfo.begin_track("### Summary ###");
    testEvaluator.printScores();
    testEvaluator.putOutput("test");
    LogInfo.end_track();
    return parser;
  }

  /** Train on the dataset, then test on both its train and test splits. */
  private OpenSemanticParser trainAndTest(Dataset dataset) {
    return test(train(dataset), dataset);
  }

  /**
   * Log each fold's summary; with multiple folds, also log the averaged
   * final-iteration train/test statistics across folds.
   */
  private void summarize(List<IterativeTester> iterativeTesters) {
    // One tester per fold; iterate the list directly rather than re-reading opts.folds.
    for (IterativeTester tester : iterativeTesters)
      tester.summarize();
    if (opts.folds > 1) {
      List<EvaluatorStatistics> trainStats = new ArrayList<>(), testStats = new ArrayList<>();
      for (IterativeTester tester : iterativeTesters) {
        trainStats.add(tester.getLastTrainStat());
        testStats.add(tester.getLastTestStat());
      }
      EvaluatorStatistics.logAverage(trainStats, "train");
      EvaluatorStatistics.logAverage(testStats, "test");
    }
  }
}