package edu.stanford.nlp.trees;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;

import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Given a list of trees, splits the trees into three separate files.
 * <br>
 * The program uses a random seed to divide the trees.  If the input
 * dataset is later extended, the same seed can be used, and trees
 * which did not change position in the dataset will be put in the
 * same division.
 * <br>
 * Example command line:
 * <code>java edu.stanford.nlp.trees.SplitTrainingSet -input foo.mrg -output bar.mrg -seed 1000</code>
 */
public class SplitTrainingSet {

  private static Redwood.RedwoodChannels logger = Redwood.channels(SplitTrainingSet.class);

  @ArgumentParser.Option(name="input", gloss="The file to use as input.", required=true)
  private static String INPUT = null;

  @ArgumentParser.Option(name="output", gloss="Where to send the splits.", required=true)
  private static String OUTPUT = null;

  @ArgumentParser.Option(name="split_names", gloss="Divisions to use for the output")
  private static String[] SPLIT_NAMES = { "train", "dev", "test" };

  @ArgumentParser.Option(name="split_weights", gloss="Portions to use for the divisions")
  private static Double[] SPLIT_WEIGHTS = { 0.7, 0.15, 0.15 };

  @ArgumentParser.Option(name="seed", gloss="Random seed to use")
  private static long SEED = 0L;

  /** Returns a random index, with index i chosen with probability weights.get(i). */
  public static int weightedIndex(List<Double> weights, Random random) {
    double offset = random.nextDouble();
    int index = 0;
    for (Double weight : weights) {
      offset = offset - weight;
      if (offset < 0.0) {
        return index;
      }
      index = index + 1;
    }
    return weights.size() - 1;
  }

  public static void main(String[] args) throws IOException {
    // Parse the arguments
    Properties props = StringUtils.argsToProperties(args);
    ArgumentParser.fillOptions(new Class[]{ArgumentParser.class, SplitTrainingSet.class}, props);
    if (SPLIT_NAMES.length != SPLIT_WEIGHTS.length) {
      throw new IllegalArgumentException("Name and weight arrays must be of the same length");
    }

    // Validate the weights and normalize them so they sum to 1.0
    double totalWeight = 0.0;
    for (Double weight : SPLIT_WEIGHTS) {
      totalWeight += weight;
      if (weight < 0.0) {
        throw new IllegalArgumentException("Split weights cannot be negative");
      }
    }
    if (totalWeight <= 0.0) {
      throw new IllegalArgumentException("Split weights must total to a positive weight");
    }
    List<Double> splitWeights = new ArrayList<>();
    for (Double weight : SPLIT_WEIGHTS) {
      splitWeights.add(weight / totalWeight);
    }
    logger.info("Splitting into " + splitWeights.size() + " lists with weights " + splitWeights);

    if (SEED == 0L) {
      SEED = System.nanoTime();
      logger.info("Random seed not set by options, using " + SEED);
    }
    Random random = new Random(SEED);

    // One output list of trees per split
    List<List<Tree>> splits = new ArrayList<>();
    for (Double weight : splitWeights) {
      splits.add(new ArrayList<>());
    }

    // Read the input trees and assign each one to a randomly chosen split
    Treebank treebank = new MemoryTreebank(in -> new PennTreeReader(in));
    treebank.loadPath(INPUT);
    logger.info("Splitting " + treebank.size() + " trees");
    for (Tree tree : treebank) {
      int index = weightedIndex(splitWeights, random);
      splits.get(index).add(tree);
    }

    // Write each split to OUTPUT.<split name>, one tree per line
    for (int i = 0; i < splits.size(); ++i) {
      String filename = OUTPUT + "." + SPLIT_NAMES[i];
      List<Tree> split = splits.get(i);
      logger.info("Writing " + split.size() + " trees to " + filename);
      try (BufferedWriter bout = new BufferedWriter(new FileWriter(filename))) {
        for (Tree tree : split) {
          bout.write(tree.toString());
          bout.newLine();
        }
      }
    }
  }

}