package edu.stanford.nlp.international.arabic.pipeline; import edu.stanford.nlp.util.logging.Redwood; import java.io.*; import java.util.*; import edu.stanford.nlp.trees.treebank.ConfigParser; import edu.stanford.nlp.trees.*; import edu.stanford.nlp.trees.international.arabic.ATBTreeUtils; import edu.stanford.nlp.util.Generics; /** * Decimates a set of ATB parse trees. For every 10 parse trees, eight are added to the training set, and one * is added to each of the dev and test sets. * * @author Spence Green * */ public class DecimatedArabicDataset extends ATBArabicDataset { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(DecimatedArabicDataset.class); private boolean taggedOutput = false; private String wordTagDelim = "_"; @Override public void build() { //Set specific options for this dataset if(options.containsKey(ConfigParser.paramSplit)) { System.err.printf("%s: Ignoring split parameter for this dataset type\n", this.getClass().getName()); } else if(options.containsKey(ConfigParser.paramTagDelim)) { wordTagDelim = options.getProperty(ConfigParser.paramTagDelim); taggedOutput = true; } for(File path : pathsToData) { int prevSize = treebank.size(); treebank.loadPath(path,treeFileExtension,false); toStringBuffer.append(String.format(" Loaded %d trees from %s\n", treebank.size() - prevSize, path.getPath())); prevSize = treebank.size(); } ArabicTreeDecimatedNormalizer tv = new ArabicTreeDecimatedNormalizer(outFileName,makeFlatFile,taggedOutput); treebank.apply(tv); outputFileList.addAll(tv.getFilenames()); tv.closeOutputFiles(); } public class ArabicTreeDecimatedNormalizer extends ArabicRawTreeNormalizer { private int treesVisited = 0; private final String trainExtension = ".train"; private final String testExtension = ".test"; private final String devExtension = ".dev"; private final String flatExtension = ".flat"; private boolean makeFlatFile = false; private boolean taggedOutput = false; private Map<String,String> outFilenames; private Map<String,PrintWriter> outFiles; public ArabicTreeDecimatedNormalizer(String filePrefix, boolean makeFlat, boolean makeTagged) { super(null,null); makeFlatFile = makeFlat; taggedOutput = makeTagged; //Setup the decimation output files outFilenames = Generics.newHashMap(); outFilenames.put(trainExtension, filePrefix + trainExtension); outFilenames.put(testExtension, filePrefix + testExtension); outFilenames.put(devExtension, filePrefix + devExtension); if(makeFlatFile) { outFilenames.put(trainExtension + flatExtension,filePrefix + trainExtension + flatExtension); outFilenames.put(testExtension + flatExtension,filePrefix + testExtension + flatExtension); outFilenames.put(devExtension + flatExtension,filePrefix + devExtension + flatExtension); } setupOutputFiles(); } private void setupOutputFiles() { PrintWriter outfile = null; String curOutFileName = ""; try { outFiles = Generics.newHashMap(); for(String keyForFile : outFilenames.keySet()) { curOutFileName = outFilenames.get(keyForFile); if(!makeFlatFile && curOutFileName.contains(flatExtension)) continue; outfile = new PrintWriter(new BufferedWriter(new OutputStreamWriter(new FileOutputStream(curOutFileName),"UTF-8"))); outFiles.put(keyForFile, outfile); } } catch (UnsupportedEncodingException e) { System.err.printf("%s: Filesystem does not support UTF-8 output\n", this.getClass().getName()); e.printStackTrace(); } catch (FileNotFoundException e) { System.err.printf("%s: Could not open %s for writing\n", this.getClass().getName(), curOutFileName); } } public void closeOutputFiles() { for(String keyForFile : outFiles.keySet()) outFiles.get(keyForFile).close(); } public void visitTree(Tree t) { if(t == null || t.value().equals("X")) return; t = t.prune(nullFilter, new LabeledScoredTreeFactory()); //Do *not* strip traces here. The ArabicTreeReader will do that if needed for(Tree node : t) if(node.isPreTerminal()) processPreterminal(node); treesVisited++; String flatString = (makeFlatFile) ? ATBTreeUtils.flattenTree(t) : null; //Do the decimation if(treesVisited % 9 == 0) { write(t, outFiles.get(devExtension)); if(makeFlatFile) outFiles.get(devExtension + flatExtension).println(flatString); } else if(treesVisited % 10 == 0) { write(t, outFiles.get(testExtension)); if(makeFlatFile) outFiles.get(testExtension + flatExtension).println(flatString); } else { write(t, outFiles.get(trainExtension)); if(makeFlatFile) outFiles.get(trainExtension + flatExtension).println(flatString); } } private void write(Tree t, PrintWriter pw) { if(taggedOutput) pw.println(ATBTreeUtils.taggedStringFromTree(t, removeEscapeTokens, wordTagDelim)); else t.pennPrint(pw); } public List<String> getFilenames() { List<String> filenames = new ArrayList<>(); for(String keyForFile : outFilenames.keySet()) filenames.add(outFilenames.get(keyForFile)); return filenames; } } }