package edu.stanford.nlp.semparse.open.dataset.library; import java.io.*; import java.nio.charset.Charset; import java.nio.file.*; import java.util.*; import com.fasterxml.jackson.databind.ObjectMapper; import edu.stanford.nlp.semparse.open.dataset.CriteriaExactMatch; import edu.stanford.nlp.semparse.open.dataset.CriteriaGeneralWeb; import edu.stanford.nlp.semparse.open.dataset.Dataset; import edu.stanford.nlp.semparse.open.dataset.Example; import edu.stanford.nlp.semparse.open.dataset.ExampleCached; import edu.stanford.nlp.semparse.open.dataset.ExpectedAnswer; import edu.stanford.nlp.semparse.open.dataset.ExpectedAnswerCriteriaMatch; import edu.stanford.nlp.semparse.open.dataset.ExpectedAnswerInjectiveMatch; import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity; import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntityNearMatch; import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntitySubstring; import edu.stanford.nlp.semparse.open.dataset.library.JSONDataset.JSONDatasetDatum; import fig.basic.LogInfo; import fig.basic.Option; /** * Read a dataset from JSON file and create a Dataset instance. * * See JSONDataset for file format. */ public class JSONDatasetReader { public static class Options { @Option(gloss = "Fraction of examples to use for training (from the start)") public double trainFrac = 0.8; @Option(gloss = "Fraction of examples to use for testing (from the end)") public double testFrac = 0.2; @Option public boolean fuzzyStringMatching = true; @Option public boolean zeroOneLoss = true; @Option public boolean shuffleDataset = false; @Option public long shuffleDatasetSeed = 42; } public static Options opts = new Options(); public final String family, name; public JSONDatasetReader(String family, String name) { this.family = family; this.name = name; } // ============================================================ // Get Dataset object // ============================================================ public Dataset getDataset() throws IOException { return getDataset(this.name); } public Dataset getDataset(String name) throws IOException { if (name == null) throw new RuntimeException("No dataset specified."); // Separate train and test data: [trainData]@[testData] if (name.contains("@")) { Dataset dataset = new Dataset(); String[] parts = name.split("@"); if (parts.length != 2) { LogInfo.fails("train@test syntax needs 2 datasets; got %d", parts.length); } dataset.addTrainFromDataset(getDataset(parts[0])); dataset.addTestFromDataset(getDataset(parts[1])); return dataset; } // Combine datasets: [data1]+[data2] if (name.contains("+")) { Dataset dataset = new Dataset(); for (String subName : name.split("[+]")) { Dataset subDataset = getDataset(subName); dataset.addFromDataset(subDataset); } return dataset; } // Single dataset Path path = Paths.get("datasets", family, name + ".json"); List<Example> examples = new ArrayList<>(); try (BufferedReader reader = Files.newBufferedReader(path, Charset.forName("UTF-8"))) { LogInfo.begin_track("Reading dataset from %s", path); // Read the JSON file ObjectMapper mapper = new ObjectMapper(); JSONDataset jsonDataset = mapper.readValue(reader, JSONDataset.class); LogInfo.log(jsonDataset.options); boolean firstTime = true; // Use to log the information only once for (JSONDatasetDatum datum : jsonDataset.data) { // Create the example ExpectedAnswer expectedAnswer; if (opts.zeroOneLoss) { expectedAnswer = getZeroOneLossExpectedAnswer(datum, jsonDataset, firstTime); } else { expectedAnswer = getIRExpectedAnswer(datum, jsonDataset, firstTime); } // Add the created example to the example list. if (jsonDataset.options.useHashcode) { examples.add(new ExampleCached(datum.query, jsonDataset.options.cacheDirectory, datum.hashcode, datum.url, expectedAnswer)); } else { examples.add(new Example(datum.query, expectedAnswer)); } firstTime = false; } LogInfo.end_track(); } // First trainFrac are training, last testFrac are test // Note that the examples can overlap and also don't have to cover all the examples. int trainEnd = (int)(examples.size() * opts.trainFrac); int testStart = (int)(examples.size() * (1 - opts.testFrac)); // Create the data set if (opts.shuffleDataset) { Collections.shuffle(examples, new Random(opts.shuffleDatasetSeed)); } Dataset dataset = new Dataset(); for (int i = 0; i < examples.size(); i++) { Example ex = examples.get(i); if (i < trainEnd) dataset.addTrainExample(ex); if (i >= testStart) dataset.addTestExample(ex); } return dataset; } // ============================================================ // Helper Methods // ============================================================ private ExpectedAnswer getZeroOneLossExpectedAnswer(JSONDatasetDatum datum, JSONDataset jsonDataset, boolean verbose) { if (jsonDataset.options.detailed) { // Use criteria matching (0-1 loss / exact match on first, second, and last entities) if (verbose) { LogInfo.log("Using 0-1 loss (must match first, second, and last entities to get reward = 1)"); } return new ExpectedAnswerCriteriaMatch(new CriteriaGeneralWeb(datum)); } else { // Use exact matching (0-1 loss / exact match on all entities) if (verbose) LogInfo.log("Using 0-1 loss (must match all entities to get reward = 1)"); List<TargetEntity> targetEntities = new ArrayList<>(); for (String entity : datum.entities) targetEntities.add(getTargetEntity(entity)); return new ExpectedAnswerCriteriaMatch(new CriteriaExactMatch(targetEntities)); } } private ExpectedAnswer getIRExpectedAnswer(JSONDatasetDatum datum, JSONDataset jsonDataset, boolean verbose) { // Use IR score-based matching (e.g. F1 > 80) if (verbose) LogInfo.logs("Using IR-based loss (must have %s >= %f to get positive reward)", ExpectedAnswerInjectiveMatch.opts.irCriterion, ExpectedAnswerInjectiveMatch.opts.irThreshold); // Convert each entity string to TargetEntity List<TargetEntity> targetEntities = new ArrayList<>(); for (String entity : datum.entities) targetEntities.add(getTargetEntity(entity)); return new ExpectedAnswerInjectiveMatch(targetEntities); } private TargetEntity getTargetEntity(String entity) { return opts.fuzzyStringMatching ? new TargetEntityNearMatch(entity) : new TargetEntitySubstring(entity); } }