package edu.stanford.nlp.parser.tools;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import edu.stanford.nlp.util.logging.Redwood;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.SentenceUtils;
import edu.stanford.nlp.parser.common.ParserGrammar;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.StringUtils;
/**
* Given a list of sentences, converts the sentences to trees and then
* relabels them using a list of new labels.
* <br>
* This tool processes the text using a given parser model, one
* sentence per line.
* <br>
* The labels file is expected to be a tab separated file. If there
* are multiple labels on a line, only the last one is used.
* <br>
* The <code>-missing</code> argument selects how phrases without a
* label are handled: FAIL, DEFAULT (used when the argument is not
* given), or KEEP_ORIGINAL.
* <br>
* The argument for providing the labels is <code>-labels</code>
* <br>
* The argument for providing the sentences is <code>-sentences</code>
* <br>
* Alternatively, one can provide the flag <code>-useLabelKeys</code>
* to specify that the keys in the labels file should be treated as
* the sentences. Exactly one of <code>-useLabelKeys</code> or
* <code>-sentences</code> must be used.
* <br>
* Example command line:
* <br>
* java edu.stanford.nlp.parser.tools.ParseAndSetLabels -output foo.txt -sentences "C:\Users\JohnBauer\Documents\alphasense\dataset\sentences10.txt" -labels "C:\Users\JohnBauer\Documents\alphasense\dataset\phrases10.tsv" -parser edu/stanford/nlp/models/srparser/englishSR.ser.gz -tagger edu/stanford/nlp/models/pos-tagger/english-left3words/english-left3words-distsim.tagger -remapLabels 0=1,1=2,2=2,3=0,4=0
*/
public class ParseAndSetLabels {
  private static final Redwood.RedwoodChannels logger = Redwood.channels(ParseAndSetLabels.class);

  /** Strategy applied to a subtree whose text has no entry in the label map. */
  public enum MissingLabels {
    FAIL, DEFAULT, KEEP_ORIGINAL
  }

  /**
   * Recursively relabels every non-leaf node of {@code tree}.
   * <br>
   * The text of each subtree is built from its yield with
   * {@link SentenceUtils#listToString} and looked up in {@code labelMap}.
   * Leaves are never relabeled.
   *
   * @param tree the tree to relabel in place
   * @param labelMap map from phrase text to the new label
   * @param missing what to do when a phrase has no entry in {@code labelMap}
   * @param defaultLabel label assigned to unknown phrases in {@code DEFAULT} mode
   * @param unknowns collects the text of phrases which received the default
   *   label; only written to in {@code DEFAULT} mode
   * @throws RuntimeException in {@code FAIL} mode if a phrase has no label
   */
  public static void setLabels(Tree tree, Map<String, String> labelMap,
                               MissingLabels missing, String defaultLabel,
                               Set<String> unknowns) {
    if (tree.isLeaf()) {
      return;
    }

    String text = SentenceUtils.listToString(tree.yield());
    String label = labelMap.get(text);
    if (label != null) {
      tree.label().setValue(label);
    } else {
      switch (missing) {
      case FAIL:
        throw new RuntimeException("No label for '" + text + "'");
      case DEFAULT:
        tree.label().setValue(defaultLabel);
        unknowns.add(text);
        break;
      case KEEP_ORIGINAL:
        // do nothing
        break;
      default:
        throw new IllegalArgumentException("Unknown MissingLabels mode " + missing);
      }
    }

    for (Tree child : tree.children()) {
      setLabels(child, labelMap, missing, defaultLabel, unknowns);
    }
  }

  /**
   * Relabels each tree in {@code trees} in place.
   *
   * @return the set of phrases which were given {@code defaultLabel} because
   *   they had no entry in {@code labelMap}; empty unless {@code missing} is
   *   {@code DEFAULT}
   */
  public static Set<String> setLabels(List<Tree> trees, Map<String, String> labelMap,
                                      MissingLabels missing, String defaultLabel) {
    logger.info("Setting labels");

    Set<String> unknowns = new HashSet<>();
    for (Tree tree : trees) {
      setLabels(tree, labelMap, missing, defaultLabel, unknowns);
    }
    return unknowns;
  }

  /**
   * Writes the trees to {@code outputFile}, one {@code toString()} per line.
   *
   * @throws RuntimeIOException if the file cannot be written
   */
  public static void writeTrees(List<Tree> trees, String outputFile) {
    logger.info("Writing new trees to " + outputFile);

    // try-with-resources so the writer is closed even if a write fails
    try (BufferedWriter out = new BufferedWriter(new FileWriter(outputFile))) {
      for (Tree tree : trees) {
        out.write(tree.toString());
        out.write("\n");
      }
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  /**
   * Writes the phrases which had no label to {@code unknownsFile}, one per
   * line.  The phrases are sorted so that the output is deterministic.
   *
   * @throws RuntimeIOException if the file cannot be written
   */
  private static void writeUnknowns(Set<String> unknowns, String unknownsFile) {
    logger.info("Writing " + unknowns.size() + " unknown phrases to " + unknownsFile);

    List<String> sorted = new ArrayList<>(unknowns);
    Collections.sort(sorted);
    try (BufferedWriter out = new BufferedWriter(new FileWriter(unknownsFile))) {
      for (String phrase : sorted) {
        out.write(phrase);
        out.write("\n");
      }
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  /**
   * Reads a map from phrase to label out of {@code labelsFile}.
   * <br>
   * Each line is split on {@code separator}.  The first piece is the phrase
   * and the last piece is the label.  A line which contains no separator
   * therefore maps the phrase to itself.  Lines which are blank or which
   * contain nothing but separators are skipped.
   *
   * @param labelsFile the file to read
   * @param separator regular expression used to split each line
   * @param remapLabels if not null, a description of a label to label map
   *   (parsed with {@link StringUtils#mapStringToMap}) which is applied to
   *   each label read; labels without an entry are kept as is
   */
  public static Map<String, String> readLabelMap(String labelsFile, String separator, String remapLabels) {
    logger.info("Reading labels from " + labelsFile);

    Map<String, String> remap = Collections.emptyMap();
    if (remapLabels != null) {
      remap = StringUtils.mapStringToMap(remapLabels);
      logger.info("Remapping labels using " + remap);
    }

    Map<String, String> labelMap = new HashMap<>();
    for (String phrase : IOUtils.readLines(labelsFile)) {
      if (phrase.trim().isEmpty()) {
        // a blank line would otherwise add an empty phrase to the map
        continue;
      }
      String[] pieces = phrase.split(separator);
      if (pieces.length == 0) {
        // split returns an empty array when the line is nothing but separators
        continue;
      }
      String label = pieces[pieces.length - 1];
      if (remap.containsKey(label)) {
        label = remap.get(label);
      }
      labelMap.put(pieces[0], label);
    }
    return labelMap;
  }

  /**
   * Reads the sentences to parse from {@code sentencesFile}, one sentence
   * per line.
   */
  public static List<String> readSentences(String sentencesFile) {
    logger.info("Reading sentences from " + sentencesFile);

    List<String> sentences = new ArrayList<>();
    for (String sentence : IOUtils.readLines(sentencesFile)) {
      sentences.add(sentence);
    }
    return sentences;
  }

  /**
   * Loads the parser model from {@code parserFile}.  If {@code taggerFile}
   * is not null, the model is loaded with the {@code -preTag} and
   * {@code -taggerSerializedFile} options so that it uses that tagger.
   */
  public static ParserGrammar loadParser(String parserFile, String taggerFile) {
    if (taggerFile != null) {
      return ParserGrammar.loadModel(parserFile, "-preTag", "-taggerSerializedFile", taggerFile);
    } else {
      return ParserGrammar.loadModel(parserFile);
    }
  }

  /**
   * Parses each sentence with {@code parser}, in order.
   *
   * @param binarizer if not null, applied to each tree after parsing
   * @return one tree per sentence, in the same order as {@code sentences}
   */
  public static List<Tree> parseSentences(List<String> sentences, ParserGrammar parser, TreeBinarizer binarizer) {
    logger.info("Parsing sentences");

    List<Tree> trees = new ArrayList<>();
    for (String sentence : sentences) {
      Tree tree = parser.parse(sentence);
      if (binarizer != null) {
        tree = binarizer.transformTree(tree);
      }
      trees.add(tree);
      // periodic progress report for long inputs
      if (trees.size() % 1000 == 0) {
        logger.info("  Parsed " + trees.size() + " trees");
      }
    }
    return trees;
  }

  /**
   * Returns the value which follows the flag at {@code argIndex}.
   *
   * @throws IllegalArgumentException if the flag is the last argument
   */
  private static String argValue(String[] args, int argIndex) {
    if (argIndex + 1 >= args.length) {
      throw new IllegalArgumentException("Argument " + args[argIndex] + " requires a value");
    }
    return args[argIndex + 1];
  }

  /**
   * Command line entry point.
   * <br>
   * Arguments which take a value: <code>-output</code>,
   * <code>-sentences</code>, <code>-labels</code>, <code>-parser</code>,
   * <code>-tagger</code>, <code>-missing</code>, <code>-separator</code>,
   * <code>-default</code>, <code>-saveUnknowns</code>,
   * <code>-remapLabels</code>.
   * <br>
   * Flags: <code>-binarize</code>, <code>-nobinarize</code>,
   * <code>-useLabelKeys</code>, <code>-nouseLabelKeys</code>.
   * <br>
   * If <code>-saveUnknowns</code> is given, the phrases which received the
   * default label are written to that file, one per line.
   *
   * @throws IllegalArgumentException if an argument is unknown, is missing
   *   its value, or if the required arguments are not provided
   */
  public static void main(String[] args) {
    // TODO: rather than always rolling our own arg parser, we should
    // find a library which does it for us nicely
    String outputFile = null;
    String sentencesFile = null;
    String labelsFile = null;
    String parserFile = LexicalizedParser.DEFAULT_PARSER_LOC;
    String taggerFile = null;
    MissingLabels missing = MissingLabels.DEFAULT;
    String defaultLabel = "-1";
    String separator = "\\t+";
    String saveUnknownsFile = null;
    String remapLabels = null;
    boolean binarize = true;
    boolean useLabelKeys = false;

    int argIndex = 0;
    while (argIndex < args.length) {
      String arg = args[argIndex];
      if (arg.equalsIgnoreCase("-output")) {
        outputFile = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-sentences")) {
        sentencesFile = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-labels")) {
        labelsFile = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-parser")) {
        parserFile = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-tagger")) {
        taggerFile = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-missing")) {
        missing = MissingLabels.valueOf(argValue(args, argIndex));
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-separator")) {
        separator = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-default")) {
        defaultLabel = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-saveUnknowns")) {
        saveUnknownsFile = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-remapLabels")) {
        remapLabels = argValue(args, argIndex);
        argIndex += 2;
      } else if (arg.equalsIgnoreCase("-binarize")) {
        binarize = true;
        argIndex += 1;
      } else if (arg.equalsIgnoreCase("-nobinarize")) {
        binarize = false;
        argIndex += 1;
      } else if (arg.equalsIgnoreCase("-useLabelKeys")) {
        useLabelKeys = true;
        argIndex += 1;
      } else if (arg.equalsIgnoreCase("-nouseLabelKeys")) {
        useLabelKeys = false;
        argIndex += 1;
      } else {
        throw new IllegalArgumentException("Unknown argument " + arg);
      }
    }

    if (outputFile == null) {
      throw new IllegalArgumentException("-output is required");
    }
    if (sentencesFile == null && !useLabelKeys) {
      throw new IllegalArgumentException("-sentences or -useLabelKeys is required");
    }
    if (sentencesFile != null && useLabelKeys) {
      throw new IllegalArgumentException("Use only one of -sentences or -useLabelKeys");
    }
    if (labelsFile == null) {
      throw new IllegalArgumentException("-labels is required");
    }

    ParserGrammar parser = loadParser(parserFile, taggerFile);
    TreeBinarizer binarizer = null;
    if (binarize) {
      binarizer = TreeBinarizer.simpleTreeBinarizer(parser.getTLPParams().headFinder(), parser.treebankLanguagePack());
    }

    Map<String, String> labelMap = readLabelMap(labelsFile, separator, remapLabels);

    List<String> sentences;
    if (sentencesFile != null) {
      sentences = readSentences(sentencesFile);
    } else {
      // -useLabelKeys: the phrases of the labels file are the sentences
      sentences = new ArrayList<>(labelMap.keySet());
    }

    List<Tree> trees = parseSentences(sentences, parser, binarizer);
    Set<String> unknowns = setLabels(trees, labelMap, missing, defaultLabel);
    writeTrees(trees, outputFile);

    if (saveUnknownsFile != null) {
      writeUnknowns(unknowns, saveUnknownsFile);
    }
  }
}