package edu.stanford.nlp.semparse.open.dataset;
import java.util.*;
import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntity;
import edu.stanford.nlp.semparse.open.dataset.entity.TargetEntityPersonName;
import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
import fig.basic.LogInfo;
/**
 * A Dataset is a collection of {@link Example}s (data instances) partitioned
 * into a training list and a test list.
 *
 * <p>Both lists are exposed as public mutable fields; the {@code add*} methods
 * return {@code this} so that calls can be chained.
 */
public class Dataset {

  /** Seed used for every shuffle in this class, so that repeated runs shuffle the same way. */
  private static final long SHUFFLE_SEED = 42;

  /** Examples used for training. */
  public final List<Example> trainExamples = new ArrayList<>();

  /** Examples used for testing. */
  public final List<Example> testExamples = new ArrayList<>();

  /** Creates a Dataset with no examples. */
  public Dataset() {
    // Do nothing
  }

  /**
   * Creates a Dataset containing copies of the given lists.
   *
   * @param train examples to copy into the training list
   * @param test examples to copy into the test list
   */
  public Dataset(List<Example> train, List<Example> test) {
    trainExamples.addAll(train);
    testExamples.addAll(test);
  }

  /**
   * Appends one example to the training list.
   *
   * @param ex the example to add
   * @return this Dataset
   */
  public Dataset addTrainExample(Example ex) {
    trainExamples.add(ex);
    return this;
  }

  /**
   * Appends one example to the test list.
   *
   * @param ex the example to add
   * @return this Dataset
   */
  public Dataset addTestExample(Example ex) {
    testExamples.add(ex);
    return this;
  }

  /**
   * Appends the other dataset's training examples to this training list and
   * its test examples to this test list.
   *
   * @param that the dataset to copy examples from
   * @return this Dataset
   */
  public Dataset addFromDataset(Dataset that) {
    this.trainExamples.addAll(that.trainExamples);
    this.testExamples.addAll(that.testExamples);
    return this;
  }

  /**
   * Appends ALL examples of the other dataset (training first, then test) to
   * this training list.
   *
   * @param that the dataset to copy examples from
   * @return this Dataset
   */
  public Dataset addTrainFromDataset(Dataset that) {
    this.trainExamples.addAll(that.trainExamples);
    this.trainExamples.addAll(that.testExamples);
    return this;
  }

  /**
   * Appends ALL examples of the other dataset (training first, then test) to
   * this test list.
   *
   * @param that the dataset to copy examples from
   * @return this Dataset
   */
  public Dataset addTestFromDataset(Dataset that) {
    this.testExamples.addAll(that.trainExamples);
    this.testExamples.addAll(that.testExamples);
    return this;
  }

  /**
   * Returns a new Dataset whose examples are a shuffle of all examples in this
   * one. The training and test lists keep their original sizes, and this
   * Dataset is not modified.
   *
   * @return the shuffled Dataset
   */
  public Dataset getNewShuffledDataset() {
    List<Example> allExamples = shuffledCopyOfAllExamples();
    int trainEndIndex = trainExamples.size();
    // The sub-list views are copied by the Dataset constructor.
    return new Dataset(
        allExamples.subList(0, trainEndIndex),
        allExamples.subList(trainEndIndex, allExamples.size()));
  }

  /**
   * Returns a new Dataset in which all examples of this one are shuffled and
   * re-divided so that the first {@code (int) (total * trainRatio)} of them
   * become training examples and the rest become test examples. This Dataset
   * is not modified.
   *
   * @param trainRatio fraction of all examples to use for training
   * @return the re-split Dataset
   * @throws IndexOutOfBoundsException if the computed split point falls
   *     outside the range of available examples
   */
  public Dataset getNewSplitDataset(double trainRatio) {
    List<Example> allExamples = shuffledCopyOfAllExamples();
    int trainEndIndex = (int) (allExamples.size() * trainRatio);
    return new Dataset(
        allExamples.subList(0, trainEndIndex),
        allExamples.subList(trainEndIndex, allExamples.size()));
  }

  /** Returns a fresh list of training-then-test examples, shuffled with the fixed seed. */
  private List<Example> shuffledCopyOfAllExamples() {
    List<Example> allExamples = new ArrayList<>(trainExamples);
    allExamples.addAll(testExamples);
    Collections.shuffle(allExamples, new Random(SHUFFLE_SEED));
    return allExamples;
  }

  // ============================================================
  // Caching rewards
  // ============================================================

  /**
   * Calls {@code expectedAnswer.reward(candidate)} for every candidate of every
   * example (training and test) whose expected answer is not yet marked
   * {@code frozenReward}, then sets that flag.
   *
   * <p>The returned rewards are discarded here; the calls are made for their
   * effect on the ExpectedAnswer (presumably it memoizes the reward -- confirm
   * in ExpectedAnswer). If a reward computation throws, the flag of that
   * example is left unset and the log tracks opened here are still closed.
   */
  public void cacheRewards() {
    List<Example> uncached = new ArrayList<>();
    for (Example ex : trainExamples) {
      if (!ex.expectedAnswer.frozenReward) {
        uncached.add(ex);
      }
    }
    for (Example ex : testExamples) {
      if (!ex.expectedAnswer.frozenReward) {
        uncached.add(ex);
      }
    }
    if (uncached.isEmpty()) {
      return;
    }
    LogInfo.begin_track("Cache rewards ...");
    try {
      for (Example ex : uncached) {
        LogInfo.begin_track("Computing rewards for example %s ...", ex);
        try {
          for (Candidate candidate : ex.candidates) {
            ex.expectedAnswer.reward(candidate);
          }
          ex.expectedAnswer.frozenReward = true;
        } finally {
          // Keep begin_track/end_track balanced even when reward() throws.
          LogInfo.end_track();
        }
      }
    } finally {
      LogInfo.end_track();
    }
  }

  // ============================================================
  // Shorthands for creating datasets.
  // ============================================================

  /**
   * Creates an example and adds it to the training list.
   *
   * @param phrase the phrase of the example
   * @param expectedAnswer the expected answer of the example
   * @return the newly created example
   */
  public Example E(String phrase, ExpectedAnswer expectedAnswer) {
    return E(phrase, expectedAnswer, true);
  }

  /**
   * Creates an example and adds it to the training or the test list.
   *
   * @param phrase the phrase of the example
   * @param expectedAnswer the expected answer of the example
   * @param isTrain {@code true} to add to the training list, {@code false} for the test list
   * @return the newly created example
   */
  public Example E(String phrase, ExpectedAnswer expectedAnswer, boolean isTrain) {
    Example ex = new Example(phrase, expectedAnswer);
    if (isTrain) {
      addTrainExample(ex);
    } else {
      addTestExample(ex);
    }
    return ex;
  }

  /**
   * Shorthand for {@code L(false, items)}.
   *
   * @param items the expected answer strings
   * @return the expected answer
   */
  public ExpectedAnswer L(String... items) {
    return L(false, items);
  }

  /**
   * Creates an {@link ExpectedAnswerInjectiveMatch} from plain strings.
   *
   * @param exact currently ignored
   * @param items the expected answer strings
   * @return the expected answer
   */
  public ExpectedAnswer L(boolean exact, String... items) {
    return new ExpectedAnswerInjectiveMatch(items);
  }

  /**
   * Shorthand for {@code LN(false, items)}.
   *
   * @param items full person names, each of two or three words
   * @return the expected answer
   */
  public ExpectedAnswer LN(String... items) {
    return LN(false, items);
  }

  /**
   * Creates an {@link ExpectedAnswerInjectiveMatch} whose targets are person
   * names: every item is converted with {@link #N(String)} first.
   *
   * @param exact currently ignored
   * @param items full person names, each of two or three words
   * @return the expected answer
   * @throws IllegalArgumentException if an item is not two or three words
   */
  public ExpectedAnswer LN(boolean exact, String... items) {
    TargetEntity[] targetEntities = new TargetEntity[items.length];
    for (int i = 0; i < items.length; i++) {
      targetEntities[i] = N(items[i]);
    }
    // Pass the person-name entities, not the raw strings; otherwise this
    // method would be indistinguishable from L(...).
    return new ExpectedAnswerInjectiveMatch(targetEntities);
  }

  /**
   * Builds a person-name entity from a full name of two or three
   * whitespace-separated words. Leading/trailing whitespace is dropped and
   * runs of whitespace count as a single separator.
   *
   * @param full the full name
   * @return the person-name entity built from the two or three words, in order
   * @throws IllegalArgumentException if the name is not two or three words
   */
  public TargetEntityPersonName N(String full) {
    String[] parts = full.trim().split("\\s+");
    if (parts.length == 2) {
      return new TargetEntityPersonName(parts[0], parts[1]);
    }
    if (parts.length == 3) {
      return new TargetEntityPersonName(parts[0], parts[1], parts[2]);
    }
    throw new IllegalArgumentException("N(...) requires two or three words.");
  }
}