package edu.stanford.nlp.coref.docreader;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Properties;
import java.util.Set;
import java.util.Stack;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import edu.stanford.nlp.coref.CorefCoreAnnotations;
import edu.stanford.nlp.coref.data.InputDoc;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreAnnotations.SentencesAnnotation;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.ChunkAnnotationUtils;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.semgraph.SemanticGraphFactory;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.GrammaticalStructure;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.LabeledScoredTreeReaderFactory;
import edu.stanford.nlp.trees.ModCollinsHeadFinder;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.TreeCoreAnnotations.TreeAnnotation;
import edu.stanford.nlp.trees.TreeNormalizer;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalStructure;
import edu.stanford.nlp.trees.international.pennchinese.ChineseSemanticHeadFinder;
import edu.stanford.nlp.util.AbstractIterator;
import edu.stanford.nlp.util.CollectionFactory;
import edu.stanford.nlp.util.CollectionValuedMap;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Filters;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.IntPair;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.logging.Redwood;
/**
* Reads the _conll file format from CoNLL-2011. See http://conll.bbn.com/index.php/data.html.
*
* CoNLL2011 files are in /scr/nlp/data/conll-2011/v0/data/
* dev
* train
* These contain *_auto_conll files (automatically generated) and *_gold_conll files (hand labelled); by default the reader uses the *_gold_conll files.
* There is also /scr/nlp/data/conll-2011/v0/conll.trial, which has *.conll files (the parse has _ at the end)
*
* Column Type Description
* 1 Document ID This is a variation on the document filename
* 2 Part number Some files are divided into multiple parts numbered as 000, 001, 002, ... etc.
* 3 Word number
* 4 Word itself
* 5 Part-of-Speech
* 6 Parse bit This is the bracketed structure broken before the first open parenthesis in the parse, with the word/part-of-speech leaf replaced by a *. The full parse can be created by substituting the asterisk with the "([pos] [word])" string (or leaf) and concatenating the items in the rows of that column.
* 7 Predicate lemma The predicate lemma is mentioned for the rows for which we have semantic role information. All other rows are marked with a "-"
* 8 Predicate Frameset ID This is the PropBank frameset ID of the predicate in Column 7.
* 9 Word sense This is the word sense of the word in Column 3.
* 10 Speaker/Author This is the speaker or author name where available. Mostly in Broadcast Conversation and Web Log data.
* 11 Named Entities This column identifies the spans representing various named entities.
* 12:N Predicate Arguments There is one column of predicate-argument structure information for each predicate mentioned in Column 7.
* N Coreference Coreference chain information encoded in a parenthesis structure.
*
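* Example usage (a minimal sketch; the corpus path is hypothetical):
* <pre>
* CoNLLDocumentReader reader = new CoNLLDocumentReader("/path/to/conll-2011/data/dev");
* for (CoNLLDocumentReader.CoNLLDocument doc; (doc = reader.getNextDocument()) != null; ) {
*   Annotation anno = doc.getAnnotation();
*   // ... inspect anno, doc.getCorefChainMap(), etc. ...
* }
* reader.close();
* </pre>
*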
* @author Angel Chang, Heeyoung Lee
*/
public class CoNLLDocumentReader implements DocReader {
/** A logger for this class */
private static Redwood.RedwoodChannels log = Redwood.channels(CoNLLDocumentReader.class);
private static final int FIELD_LAST = -1;
private static final int FIELD_DOC_ID = 0;
private static final int FIELD_PART_NO = 1;
private static final int FIELD_WORD_NO = 2;
private static final int FIELD_WORD = 3;
private static final int FIELD_POS_TAG = 4;
private static final int FIELD_PARSE_BIT = 5;
// private static final int FIELD_PRED_LEMMA = 6;
// private static final int FIELD_PRED_FRAMESET_ID = 7;
// private static final int FIELD_WORD_SENSE = 8;
private static final int FIELD_SPEAKER_AUTHOR = 9;
private static final int FIELD_NER_TAG = 10;
// private static final int FIELD_PRED_ARGS = 11; // Predicate args follow...
private static final int FIELD_COREF = FIELD_LAST; // Last field
private static final int FIELDS_MIN = 12; // There should be at least 12 fields
private DocumentIterator docIterator;
// private String filepath;
protected final List<File> fileList;
private int curFileIndex;
private final Options options;
public static final Logger logger = Logger.getLogger(CoNLLDocumentReader.class.getName());
private static final HeadFinder chineseHeadFinder = new ChineseSemanticHeadFinder();
public CoNLLDocumentReader(String filepath)
{
this(filepath, new Options());
}
public CoNLLDocumentReader(String filepath, Options options)
{
// this.filepath = filepath;
if (filepath != null && new File(filepath).exists()) {
this.fileList = getFiles(filepath, options.filePattern);
} else {
this.fileList = Collections.emptyList();
}
this.options = options;
if (options.sortFiles) {
Collections.sort(this.fileList);
}
curFileIndex = 0;
if (filepath != null && new File(filepath).exists() && options.printConLLLoadingMessage) {
logger.info("Reading " + fileList.size() + " CoNLL files from " + filepath);
}
}
private static List<File> getFiles(String filepath, Pattern filter)
{
Iterable<File> iter = IOUtils.iterFilesRecursive(new File(filepath), filter);
List<File> fileList = new ArrayList<>();
for (File f:iter) {
fileList.add(f);
}
Collections.sort(fileList);
return fileList;
}
public void reset() {
curFileIndex = 0;
if (docIterator != null) {
docIterator.close();
docIterator = null;
}
}
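/** Returns the next CoNLL document, moving on to the next file in the file list as needed; returns null once all files have been read. */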
public CoNLLDocument getNextDocument()
{
try {
if (curFileIndex >= fileList.size()) return null; // DONE!
File curFile = fileList.get(curFileIndex);
if (docIterator == null) {
docIterator = new DocumentIterator(curFile.getAbsolutePath(), options);
}
while ( ! docIterator.hasNext()) {
Redwood.log("debug-docreader", "Processed " + docIterator.docCnt + " documents in " + curFile.getAbsolutePath());
docIterator.close();
curFileIndex++;
if (curFileIndex >= fileList.size()) {
return null; // DONE!
}
curFile = fileList.get(curFileIndex);
docIterator = new DocumentIterator(curFile.getAbsolutePath(), options);
}
CoNLLDocument next = docIterator.next();
Redwood.log("debug-docreader", "Reading document: " + next.getDocumentID()+" part: "+next.getPartNo());
return next;
} catch (IOException ex) {
throw new RuntimeIOException(ex);
}
}
public void close()
{
IOUtils.closeIgnoringExceptions(docIterator);
}
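/** Attached to a tree node's label to record the NER chunk (a CoreMap) that the node spans. */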
public static class NamedEntityAnnotation implements CoreAnnotation<CoreMap> {
public Class<CoreMap> getType() {
return CoreMap.class;
}
}
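/** Attached to a tree node's label to record the coref mention chunk (a CoreMap) that the node spans. */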
public static class CorefMentionAnnotation implements CoreAnnotation<CoreMap> {
public Class<CoreMap> getType() {
return CoreMap.class;
}
}
/** Flags **/
public static class Options {
public boolean printConLLLoadingMessage = true;
public boolean useCorefBIOESEncoding = false; // Marks Coref mentions with prefix
// B- begin, I- inside, E- end, S- single
public boolean annotateTokenCoref = true; // Annotate token with CorefAnnotation
// If a token belongs to multiple clusters,
// the coref cluster ids are separated by '|'
public boolean annotateTokenSpeaker = true; // Annotate token with SpeakerAnnotation
public boolean annotateTokenPos = true; // Annotate token with PartOfSpeechAnnotation
public boolean annotateTokenNer = true; // Annotate token with NamedEntityTagAnnotation
public boolean annotateTreeCoref = false; // Annotate tree with CorefMentionAnnotation
public boolean annotateTreeNer = false; // Annotate tree with NamedEntityAnnotation
public Locale lang = Locale.ENGLISH;
public String backgroundNerTag = "O"; // Background NER tag
protected String fileFilter;
protected Pattern filePattern;
protected boolean sortFiles;
public Options() {
this(".*_gold_conll$"); // _gold_conll or _auto_conll or .conll
}
public Options(String filter) {
fileFilter = filter;
filePattern = Pattern.compile(fileFilter);
}
public void setFilter(String filter) {
fileFilter = filter;
filePattern = Pattern.compile(fileFilter);
}
}
public static class CoNLLDocument implements Serializable {
private static final long serialVersionUID = 6287339385357914531L;
String documentIdPart;
public String documentID;
String partNo;
public String filename;
public List<List<String[]>> sentenceWordLists = new ArrayList<>();
Annotation annotation;
CollectionValuedMap<String,CoreMap> corefChainMap;
List<CoreMap> nerChunks;
public String getDocumentID() {
return documentID;
}
public void setDocumentID(String documentID) {
this.documentID = documentID;
}
public String getPartNo() {
return partNo;
}
public void setPartNo(String partNo) {
this.partNo = partNo;
}
public List<List<String[]>> getSentenceWordLists() {
return sentenceWordLists;
}
public void addSentence(List<String[]> sentence) {
this.sentenceWordLists.add(sentence);
}
public Annotation getAnnotation() {
return annotation;
}
public void setAnnotation(Annotation annotation) {
this.annotation = annotation;
}
public CollectionValuedMap<String,CoreMap> getCorefChainMap()
{
return corefChainMap;
}
}
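// Returns the value of the given column, where FIELD_LAST means the last column of the line (used for coreference).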
private static String getField(String[] fields, int pos)
{
if (pos == FIELD_LAST) {
return fields[fields.length - 1];
} else {
return fields[pos];
}
}
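// Concatenates the given column over all words of a sentence, separated by single spaces.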
private static String concatField(List<String[]> sentWords, int pos)
{
StringBuilder sb = new StringBuilder();
for (String[] fields:sentWords) {
if (sb.length() > 0) {
sb.append(' ');
}
sb.append(getField(fields, pos));
}
return sb.toString();
}
/** Helper iterator **/
private static class DocumentIterator extends AbstractIterator<CoNLLDocument> implements Closeable {
private static final Pattern delimiterPattern = Pattern.compile("\\s+");
private static final LabeledScoredTreeReaderFactory treeReaderFactory =
new LabeledScoredTreeReaderFactory((TreeNormalizer) null);
private final Options options;
// State
String filename;
BufferedReader br;
CoNLLDocument nextDoc;
int lineCnt = 0;
int docCnt = 0;
public DocumentIterator(String filename, Options options) throws IOException {
this.options = options;
this.filename = filename;
this.br = IOUtils.getBufferedFileReader(filename);
nextDoc = readNextDocument();
}
@Override
public boolean hasNext() {
return nextDoc != null;
}
@Override
public CoNLLDocument next() {
if (nextDoc == null) {
throw new NoSuchElementException("DocumentIterator exhausted.");
}
CoNLLDocument curDoc = nextDoc;
nextDoc = readNextDocument();
return curDoc;
}
private static final Pattern starPattern = Pattern.compile("\\*");
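// Rebuilds the sentence's parse tree from the parse bits: each token's "*" placeholder is replaced by its
// "(POS word)" leaf and the pieces are concatenated, as described in the class javadoc
// (the NOPARSE marker is mapped to an X constituent).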
private static Tree wordsToParse(List<String[]> sentWords)
{
StringBuilder sb = new StringBuilder();
for (String[] fields:sentWords) {
if (sb.length() > 0) {
sb.append(' ');
}
String str = fields[FIELD_PARSE_BIT].replace("NOPARSE", "X");
String tagword = "(" + fields[FIELD_POS_TAG] + " " + fields[FIELD_WORD] + ")";
// Replace stars
int si = str.indexOf('*');
sb.append(str.substring(0, si));
sb.append(tagword);
sb.append(str.substring(si+1));
si = str.indexOf('*', si+1);
if (si >= 0) {
logger.warning(" Parse bit with multiple *: " + str);
}
}
String parseStr = sb.toString();
return Tree.valueOf(parseStr, treeReaderFactory);
}
private static List<Triple<Integer,Integer,String>> getCorefSpans(List<String[]> sentWords)
{
return getLabelledSpans(sentWords, FIELD_COREF, HYPHEN, true);
}
private static List<Triple<Integer,Integer,String>> getNerSpans(List<String[]> sentWords)
{
return getLabelledSpans(sentWords, FIELD_NER_TAG, ASTERISK, false);
}
private static final String ASTERISK = "*";
private static final String HYPHEN = "-";
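// Extracts labelled spans (start token, end token inclusive, label) from a parenthesis-encoded column,
// e.g. "(ORG*" ... "*)" in the NER column or "(42" ... "42)" in the coreference column, with '|'
// separating multiple entries; defaultMarker is the value used for tokens carrying no annotation.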
private static List<Triple<Integer,Integer,String>> getLabelledSpans(List<String[]> sentWords, int fieldIndex,
String defaultMarker, boolean checkEndLabel)
{
List<Triple<Integer,Integer,String>> spans = new ArrayList<>();
Stack<Triple<Integer,Integer, String>> openSpans = new Stack<>();
boolean removeStar = (ASTERISK.equals(defaultMarker));
for (int wordPos = 0; wordPos < sentWords.size(); wordPos++) {
String[] fields = sentWords.get(wordPos);
String val = getField(fields, fieldIndex);
if (!defaultMarker.equals(val)) {
int openParenIndex = -1;
int lastDelimiterIndex = -1;
for (int j = 0; j < val.length(); j++) {
char c = val.charAt(j);
boolean isDelimiter = false;
if (c == '(' || c == ')' || c == '|') {
if (openParenIndex >= 0) {
String s = val.substring(openParenIndex+1, j);
if (removeStar) {
s = starPattern.matcher(s).replaceAll("");
}
openSpans.push(new Triple<>(wordPos, -1, s));
openParenIndex = -1;
}
isDelimiter = true;
}
if (c == '(') {
openParenIndex = j;
} else if (c == ')') {
Triple<Integer, Integer, String> t = openSpans.pop();
if (checkEndLabel) {
// NOTE: end parens may cross (usually because mentions start or end on the same token,
// and it is just an artifact of the ordering)
String s = val.substring(lastDelimiterIndex+1, j);
if (!s.equals(t.third())) {
Stack<Triple<Integer,Integer, String>> saved = new Stack<>();
while (!s.equals(t.third())) {
// find correct match
saved.push(t);
if (openSpans.isEmpty()) {
throw new RuntimeException("Cannot find matching labelled span for " + s);
}
t = openSpans.pop();
}
while (!saved.isEmpty()) {
openSpans.push(saved.pop());
}
assert(s.equals(t.third()));
}
}
t.setSecond(wordPos);
spans.add(t);
}
if (isDelimiter) {
lastDelimiterIndex = j;
}
}
if (openParenIndex >= 0) {
String s = val.substring(openParenIndex+1, val.length());
if (removeStar) {
s = starPattern.matcher(s).replaceAll("");
}
openSpans.push(new Triple<>(wordPos, -1, s));
}
}
}
if (openSpans.size() != 0) {
throw new RuntimeException("Error extracting labelled spans for column " + fieldIndex + ": "
+ concatField(sentWords, fieldIndex));
}
return spans;
}
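// Builds a sentence CoreMap from its CoNLL rows: sets the parse tree and tokens, and
// (depending on the options) speaker, POS, NER, and coref token annotations.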
private CoreMap wordsToSentence(List<String[]> sentWords)
{
String sentText = concatField(sentWords, FIELD_WORD);
Annotation sentence = new Annotation(sentText);
Tree tree = wordsToParse(sentWords);
sentence.set(TreeCoreAnnotations.TreeAnnotation.class, tree);
List<Tree> leaves = tree.getLeaves();
// Check leaves == number of words
assert(leaves.size() == sentWords.size());
List<CoreLabel> tokens = new ArrayList<>(leaves.size());
sentence.set(CoreAnnotations.TokensAnnotation.class, tokens);
for (int i = 0; i < sentWords.size(); i++) {
String[] fields = sentWords.get(i);
int wordPos = Integer.parseInt(fields[FIELD_WORD_NO]);
assert(wordPos == i);
Tree leaf = leaves.get(i);
CoreLabel token = (CoreLabel) leaf.label();
tokens.add(token);
if (options.annotateTokenSpeaker) {
String speaker = fields[FIELD_SPEAKER_AUTHOR].replace("_", " ");
if (!HYPHEN.equals(speaker)) {
token.set(CoreAnnotations.SpeakerAnnotation.class, speaker);
}
}
}
if (options.annotateTokenPos) {
for (Tree leaf:leaves) {
CoreLabel token = (CoreLabel) leaf.label();
token.set(CoreAnnotations.PartOfSpeechAnnotation.class, leaf.parent(tree).value());
}
}
if (options.annotateTokenNer) {
List<Triple<Integer,Integer,String>> nerSpans = getNerSpans(sentWords);
for (Triple<Integer,Integer,String> nerSpan:nerSpans) {
int startToken = nerSpan.first();
int endToken = nerSpan.second(); /* inclusive */
String label = nerSpan.third();
for (int i = startToken; i <= endToken; i++) {
Tree leaf = leaves.get(i);
CoreLabel token = (CoreLabel) leaf.label();
String oldLabel = token.get(CoreAnnotations.NamedEntityTagAnnotation.class);
if (oldLabel != null) {
logger.warning("Replacing old named entity tag " + oldLabel + " with " + label);
}
token.set(CoreAnnotations.NamedEntityTagAnnotation.class, label);
}
}
for (CoreLabel token:tokens) {
if (!token.containsKey(CoreAnnotations.NamedEntityTagAnnotation.class)) {
token.set(CoreAnnotations.NamedEntityTagAnnotation.class, options.backgroundNerTag);
}
}
}
if (options.annotateTokenCoref) {
List<Triple<Integer,Integer,String>> corefSpans = getCorefSpans(sentWords);
for (Triple<Integer,Integer,String> corefSpan:corefSpans) {
int startToken = corefSpan.first();
int endToken = corefSpan.second(); /* inclusive */
String label = corefSpan.third();
for (int i = startToken; i <= endToken; i++) {
Tree leaf = leaves.get(i);
CoreLabel token = (CoreLabel) leaf.label();
String curLabel = label;
if (options.useCorefBIOESEncoding) {
String prefix;
if (startToken == endToken) {
prefix = "S-";
} else if (i == startToken) {
prefix = "B-";
} else if (i == endToken) {
prefix = "E-";
} else {
prefix = "I-";
}
curLabel = prefix + label;
}
String oldLabel = token.get(CorefCoreAnnotations.CorefAnnotation.class);
if (oldLabel != null) {
curLabel = oldLabel + "|" + curLabel;
}
token.set(CorefCoreAnnotations.CorefAnnotation.class, curLabel);
}
}
}
return sentence;
}
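/** Wraps the given sentences in a document Annotation, setting document-level tokens, per-sentence
 * token begin/end and sentence indices, and approximate character offsets (assuming one space between tokens). */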
public static Annotation sentencesToDocument(String documentID, List<CoreMap> sentences)
{
String docText = null;
Annotation document = new Annotation(docText);
document.set(CoreAnnotations.DocIDAnnotation.class, documentID);
document.set(CoreAnnotations.SentencesAnnotation.class, sentences);
// Accumulate docTokens and label sentence with overall token begin/end, and sentence index annotations
List<CoreLabel> docTokens = new ArrayList<>();
int sentenceIndex = 0;
int tokenBegin = 0;
for (CoreMap sentenceAnnotation:sentences) {
List<CoreLabel> sentenceTokens = sentenceAnnotation.get(CoreAnnotations.TokensAnnotation.class);
docTokens.addAll(sentenceTokens);
int tokenEnd = tokenBegin + sentenceTokens.size();
sentenceAnnotation.set(CoreAnnotations.TokenBeginAnnotation.class, tokenBegin);
sentenceAnnotation.set(CoreAnnotations.TokenEndAnnotation.class, tokenEnd);
sentenceAnnotation.set(CoreAnnotations.SentenceIndexAnnotation.class, sentenceIndex);
sentenceIndex++;
tokenBegin = tokenEnd;
}
document.set(CoreAnnotations.TokensAnnotation.class, docTokens);
// Put in character offsets
int i = 0;
for (CoreLabel token:docTokens) {
String tokenText = token.get(CoreAnnotations.TextAnnotation.class);
token.set(CoreAnnotations.CharacterOffsetBeginAnnotation.class, i);
i+=tokenText.length();
token.set(CoreAnnotations.CharacterOffsetEndAnnotation.class, i);
i++; // Skip space
}
for (CoreMap sentenceAnnotation:sentences) {
List<CoreLabel> sentenceTokens = sentenceAnnotation.get(CoreAnnotations.TokensAnnotation.class);
sentenceAnnotation.set(CoreAnnotations.CharacterOffsetBeginAnnotation.class,
sentenceTokens.get(0).get(CoreAnnotations.CharacterOffsetBeginAnnotation.class));
sentenceAnnotation.set(CoreAnnotations.CharacterOffsetEndAnnotation.class,
sentenceTokens.get(sentenceTokens.size()-1).get(CoreAnnotations.CharacterOffsetEndAnnotation.class));
}
return document;
}
private static Tree getLowestCommonAncestor(Tree root, int startToken, int endToken)
{
Tree leftLeaf = Trees.getLeaf(root, startToken);
Tree rightLeaf = Trees.getLeaf(root, endToken);
// todo [cdm 2013]: It might be good to climb certain unaries here, like VP or S under NP, but it's not good to climb all unaries (e.g., NP under FRAG)
return Trees.getLowestCommonAncestor(leftLeaf, rightLeaf, root);
}
private static Tree getTreeNonTerminal(Tree root, int startToken, int endToken, boolean acceptPreTerminals)
{
Tree t = getLowestCommonAncestor(root, startToken, endToken);
if (t.isLeaf()) {
t = t.parent(root);
}
if (!acceptPreTerminals && t.isPreTerminal()) {
t = t.parent(root);
}
return t;
}
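// Converts the document's raw sentence word lists into an Annotation, building the coref chain map and
// the NER chunk list along the way; chunks are attached to matching tree nodes when the corresponding
// options are enabled.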
public void annotateDocument(CoNLLDocument document)
{
List<CoreMap> sentences = new ArrayList<>(document.sentenceWordLists.size());
for (List<String[]> sentWords:document.sentenceWordLists) {
sentences.add(wordsToSentence(sentWords));
}
Annotation docAnnotation = sentencesToDocument(document.documentIdPart /*document.documentID + "." + document.partNo */, sentences);
document.setAnnotation(docAnnotation);
// Do this here so we have updated character offsets and all
CollectionValuedMap<String, CoreMap> corefChainMap = new CollectionValuedMap<>(CollectionFactory.<CoreMap>arrayListFactory());
List<CoreMap> nerChunks = new ArrayList<>();
for (int i = 0; i < sentences.size(); i++) {
CoreMap sentence = sentences.get(i);
Tree tree = sentence.get(TreeCoreAnnotations.TreeAnnotation.class);
tree.setSpans();
List<String[]> sentWords = document.sentenceWordLists.get(i);
// Get NER chunks
List<Triple<Integer,Integer,String>> nerSpans = getNerSpans(sentWords);
for (Triple<Integer,Integer,String> nerSpan:nerSpans) {
int startToken = nerSpan.first();
int endToken = nerSpan.second(); /* inclusive */
String label = nerSpan.third();
CoreMap nerChunk = ChunkAnnotationUtils.getAnnotatedChunk(sentence, startToken, endToken+1);
nerChunk.set(CoreAnnotations.NamedEntityTagAnnotation.class, label);
nerChunk.set(CoreAnnotations.SentenceIndexAnnotation.class, sentence.get(CoreAnnotations.SentenceIndexAnnotation.class));
nerChunks.add(nerChunk);
Tree t = getTreeNonTerminal(tree, startToken, endToken, true);
if (t.getSpan().getSource() == startToken && t.getSpan().getTarget() == endToken) {
nerChunk.set(TreeCoreAnnotations.TreeAnnotation.class, t);
if (options.annotateTreeNer) {
Label tlabel = t.label();
if (tlabel instanceof CoreLabel) {
((CoreLabel) tlabel).set(NamedEntityAnnotation.class, nerChunk);
}
}
}
}
List<Triple<Integer,Integer,String>> corefSpans = getCorefSpans(sentWords);
for (Triple<Integer,Integer,String> corefSpan:corefSpans) {
int startToken = corefSpan.first();
int endToken = corefSpan.second(); /* inclusive */
String corefId = corefSpan.third();
CoreMap mention = ChunkAnnotationUtils.getAnnotatedChunk(sentence, startToken, endToken+1);
mention.set(CorefCoreAnnotations.CorefAnnotation.class, corefId);
mention.set(CoreAnnotations.SentenceIndexAnnotation.class, sentence.get(CoreAnnotations.SentenceIndexAnnotation.class));
corefChainMap.add(corefId, mention);
Tree t = getTreeNonTerminal(tree, startToken, endToken, true);
mention.set(TreeCoreAnnotations.TreeAnnotation.class, t);
if (options.annotateTreeCoref) {
Label tlabel = t.label();
if (tlabel instanceof CoreLabel) {
((CoreLabel) tlabel).set(CorefMentionAnnotation.class, mention);
}
}
}
}
document.corefChainMap = corefChainMap;
document.nerChunks = nerChunks;
}
private static final String docStart = "#begin document ";
private static final int docStartLength = docStart.length();
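// Reads lines until the "#end document" marker, accumulating sentences into a CoNLLDocument;
// returns the annotated document, or null at end of file.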
public CoNLLDocument readNextDocument() {
try {
List<String[]> curSentWords = new ArrayList<>();
CoNLLDocument document = null;
for (String line; (line = br.readLine()) != null; ) {
lineCnt++;
line = line.trim();
if (line.length() != 0) {
if (line.startsWith(docStart)) {
// Start of new document
if (document != null) {
logger.warning("Unexpected begin document at line (\" + filename + \",\" + lineCnt + \")");
}
document = new CoNLLDocument();
document.filename = this.filename;
document.documentIdPart = line.substring(docStartLength);
} else if (line.startsWith("#end document")) {
// End of document
annotateDocument(document);
docCnt++;
return document;
} else {
assert document != null;
String[] fields = delimiterPattern.split(line);
if (fields.length < FIELDS_MIN) {
throw new RuntimeException("Unexpected number of field " + fields.length +
", expected >= " + FIELDS_MIN + " for line (" + filename + "," + lineCnt + "): " + line);
}
String curDocId = fields[FIELD_DOC_ID];
String partNo = fields[FIELD_PART_NO];
if (document.getDocumentID() == null) {
document.setDocumentID(curDocId);
document.setPartNo(partNo);
} else {
// Check documentID didn't suddenly change on us
assert(document.getDocumentID().equals(curDocId));
assert(document.getPartNo().equals(partNo));
}
curSentWords.add(fields);
}
} else {
// Current sentence has ended, new sentence is about to be started
if (curSentWords.size() > 0) {
assert document != null;
document.addSentence(curSentWords);
curSentWords = new ArrayList<>();
}
}
}
} catch (IOException ex) {
throw new RuntimeIOException(ex);
}
return null;
}
public void close() {
IOUtils.closeIgnoringExceptions(br);
}
} // end static class DocumentIterator
public static void usage()
{
log.info("java edu.stanford.nlp.dcoref.CoNLL2011DocumentReader [-ext <extension to match>] -i <inputpath> -o <outputfile>");
}
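// Returns the (inclusive) token span of the mention labelled corefG that starts at index, extending
// the end as long as the following tokens also carry that coref id.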
public static Pair<Integer,Integer> getMention(Integer index, String corefG, List<CoreLabel> sentenceAnno) {
Integer i = -1;
Integer end = index;
for (CoreLabel newAnno : sentenceAnno) {
i += 1;
if (i > index) {
String corefS = newAnno.get(CorefCoreAnnotations.CorefAnnotation.class);
if (corefS != null) {
String[] allC = corefS.split("\\|");
if (Arrays.asList(allC).contains(corefG)) {
end = i;
} else {
break;
}
} else {
break;
}
}
}
return Pair.makePair(index, end);
}
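// Returns true if sentenceInfo already holds a mention with the same coref id that ends at the same
// token but starts earlier, i.e. the candidate mention is contained in one already recorded.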
public static boolean include(Map<Pair<Integer,Integer>,String> sentenceInfo,
Pair<Integer,Integer> mention,
String corefG) {
Set<Pair<Integer,Integer>> keys = sentenceInfo.keySet();
for (Pair<Integer, Integer> key : keys) {
String corefS = sentenceInfo.get(key);
if (corefS != null && corefS.equals(corefG)) {
if (key.first < mention.first && key.second.equals(mention.second)) {
return true;
}
}
}
return false;
}
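// Writes the sentence as tab-separated word, POS, NER, and mention columns (marking mention head
// words with "MENTION"), mainly for debugging output.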
public static void writeTabSep(PrintWriter pw, CoreMap sentence, CollectionValuedMap<String,CoreMap> chainmap)
{
HeadFinder headFinder = new ModCollinsHeadFinder();
List<CoreLabel> sentenceAnno = sentence.get(CoreAnnotations.TokensAnnotation.class);
Tree sentenceTree = sentence.get(TreeCoreAnnotations.TreeAnnotation.class);
Map<Pair<Integer,Integer>,String> sentenceInfo = Generics.newHashMap();
Set<Tree> sentenceSubTrees = sentenceTree.subTrees();
sentenceTree.setSpans();
Map<Pair<Integer,Integer>,Tree> treeSpanMap = Generics.newHashMap();
Map<Pair<Integer,Integer>,List<Tree>> wordSpanMap = Generics.newHashMap();
for (Tree ctree : sentenceSubTrees) {
IntPair span = ctree.getSpan();
if (span != null) {
treeSpanMap.put(Pair.makePair(span.getSource(), span.getTarget()), ctree);
wordSpanMap.put(Pair.makePair(span.getSource(), span.getTarget()), ctree.getLeaves());
}
}
String[][] finalSentence;
finalSentence = new String [sentenceAnno.size()][];
Map<Pair<Integer,Integer>,String> allHeads = Generics.newHashMap();
int index = -1;
for (CoreLabel newAnno : sentenceAnno) {
index += 1;
String word = newAnno.word();
String tag = newAnno.tag();
String cat = newAnno.ner();
String coref = newAnno.get(CorefCoreAnnotations.CorefAnnotation.class);
finalSentence[index] = new String[4];
finalSentence[index][0] = word;
finalSentence[index][1] = tag;
finalSentence[index][2] = cat;
finalSentence[index][3] = coref;
if (coref == null) {
sentenceInfo.put(Pair.makePair(index, index), coref);
finalSentence[index][3] = "O";
} else {
String[] allC = coref.split("\\|");
for (String corefG : allC) {
Pair<Integer, Integer> mention = getMention(index, corefG, sentenceAnno);
if ( ! include(sentenceInfo, mention, corefG)) {
// find largest NP in mention
sentenceInfo.put(mention, corefG);
Tree mentionTree = treeSpanMap.get(mention);
String head = null;
if (mentionTree != null) {
head = mentionTree.headTerminal(headFinder).nodeString();
} else if (mention.first.equals(mention.second)) {
head = word;
}
allHeads.put(mention, head);
}
}
if (allHeads.containsValue(word)) {
finalSentence[index][3] = "MENTION";
} else {
finalSentence[index][3] = "O";
}
}
}
for (int i = 0; i < finalSentence.length; i++) {
String[] wordInfo = finalSentence[i];
if (i < finalSentence.length - 1) {
String[] nextWordInfo = finalSentence[i+1];
if (nextWordInfo[3].equals("MENTION") && nextWordInfo[0].equals("'s")) {
wordInfo[3] = "MENTION";
finalSentence[i+1][3] = "O";
}
}
pw.println(wordInfo[0] + "\t" + wordInfo[1] + "\t" + wordInfo[2] + "\t" + wordInfo[3]);
}
pw.println("");
}
public static class CorpusStats
{
IntCounter<String> mentionTreeLabelCounter = new IntCounter<>();
IntCounter<String> mentionTreeNonPretermLabelCounter = new IntCounter<>();
IntCounter<String> mentionTreePretermNonPretermNoMatchLabelCounter = new IntCounter<>();
IntCounter<String> mentionTreeMixedLabelCounter = new IntCounter<>();
IntCounter<Integer> mentionTokenLengthCounter = new IntCounter<>();
IntCounter<Integer> nerMentionTokenLengthCounter = new IntCounter<>();
int mentionExactTreeSpan = 0;
int nonPretermSpanMatches = 0;
int totalMentions = 0;
int nestedNerMentions = 0;
int nerMentions = 0;
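// Accumulates statistics for one document: mention span/parse-tree alignment, tree label
// distributions, mention lengths, and nesting of NER mentions.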
public void process(CoNLLDocument doc)
{
List<CoreMap> sentences = doc.getAnnotation().get(CoreAnnotations.SentencesAnnotation.class);
for (String id:doc.corefChainMap.keySet()) {
Collection<CoreMap> mentions = doc.corefChainMap.get(id);
for (CoreMap m:mentions) {
CoreMap sent = sentences.get(m.get(CoreAnnotations.SentenceIndexAnnotation.class));
Tree root = sent.get(TreeCoreAnnotations.TreeAnnotation.class);
Tree t = m.get(TreeCoreAnnotations.TreeAnnotation.class);
Tree npt = t;
Tree npt2 = t;
if (npt.isPreTerminal()) {
npt = npt.parent(root);
}
int sentTokenStart = sent.get(CoreAnnotations.TokenBeginAnnotation.class);
int tokenStart = m.get(CoreAnnotations.TokenBeginAnnotation.class) - sentTokenStart;
int tokenEnd = m.get(CoreAnnotations.TokenEndAnnotation.class) - sentTokenStart;
int length = tokenEnd - tokenStart;
mentionTokenLengthCounter.incrementCount(length);
// Check if exact span
IntPair span = t.getSpan();
if (span != null) {
if (span.getSource() == tokenStart && span.getTarget() == tokenEnd - 1) {
mentionExactTreeSpan++;
} else {
logger.info("Tree span is " + span + ", tree node is " + t);
logger.info("Mention span is " + tokenStart + " " + (tokenEnd - 1) + ", mention is " + m);
}
} else {
logger.warning("No span for " + t);
}
IntPair nptSpan = npt.getSpan();
if (nptSpan.getSource() == tokenStart && nptSpan.getTarget() == tokenEnd - 1) {
nonPretermSpanMatches++;
npt2 = npt;
} else {
mentionTreePretermNonPretermNoMatchLabelCounter.incrementCount(t.label().value());
logger.info("NPT: Tree span is " + span + ", tree node is " + npt);
logger.info("NPT: Mention span is " + tokenStart + " " + (tokenEnd - 1) + ", mention is " + m);
Label tlabel = t.label();
if (tlabel instanceof CoreLabel) {
CoreMap mention = ((CoreLabel) tlabel).get(CorefMentionAnnotation.class);
String corefClusterId = mention.get(CorefCoreAnnotations.CorefAnnotation.class);
Collection<CoreMap> clusteredMentions = doc.corefChainMap.get(corefClusterId);
for (CoreMap m2:clusteredMentions) {
logger.info("NPT: Clustered mention " + m2.get(CoreAnnotations.TextAnnotation.class));
}
}
}
totalMentions++;
mentionTreeLabelCounter.incrementCount(t.label().value());
mentionTreeNonPretermLabelCounter.incrementCount(npt.label().value());
mentionTreeMixedLabelCounter.incrementCount(npt2.label().value());
Label tlabel = t.label();
if (tlabel instanceof CoreLabel) {
if (((CoreLabel) tlabel).containsKey(NamedEntityAnnotation.class)) {
// walk up tree
nerMentions++;
nerMentionTokenLengthCounter.incrementCount(length);
Tree parent = t.parent(root);
while (parent != null) {
Label plabel = parent.label();
if (plabel instanceof CoreLabel) {
if (((CoreLabel) plabel).containsKey(NamedEntityAnnotation.class)) {
logger.info("NER Mention: " + m);
CoreMap parentNerChunk = ((CoreLabel) plabel).get(NamedEntityAnnotation.class);
logger.info("Nested inside NER Mention: " + parentNerChunk);
logger.info("Nested inside NER Mention parent node: " + parent);
nestedNerMentions++;
break;
}
}
parent = parent.parent(root);
}
}
}
}
}
}
private static void appendFrac(StringBuilder sb, String label, int num, int den)
{
double frac = ((double) num)/ den;
sb.append(label).append("\t").append(frac).append("\t(").append(num).append("/").append(den).append(")");
}
private static <E> void appendIntCountStats(StringBuilder sb, String label, IntCounter<E> counts)
{
sb.append(label).append("\n");
List<E> sortedKeys = Counters.toSortedList(counts);
int total = counts.totalIntCount();
for (E key:sortedKeys) {
int count = counts.getIntCount(key);
appendFrac(sb, key.toString(), count, total);
sb.append("\n");
}
}
public String toString()
{
StringBuilder sb = new StringBuilder();
appendIntCountStats(sb, "Mention Tree Labels (no preterminals)", mentionTreeNonPretermLabelCounter);
sb.append("\n");
appendIntCountStats(sb, "Mention Tree Labels (with preterminals)", mentionTreeLabelCounter);
sb.append("\n");
appendIntCountStats(sb, "Mention Tree Labels (preterminals with parent span not match)", mentionTreePretermNonPretermNoMatchLabelCounter);
sb.append("\n");
appendIntCountStats(sb, "Mention Tree Labels (mixed)", mentionTreeMixedLabelCounter);
sb.append("\n");
appendIntCountStats(sb, "Mention Lengths", mentionTokenLengthCounter);
sb.append("\n");
appendFrac(sb, "Mention Exact Non Preterm Tree Span", nonPretermSpanMatches, totalMentions);
sb.append("\n");
appendFrac(sb, "Mention Exact Tree Span", mentionExactTreeSpan, totalMentions);
sb.append("\n");
appendFrac(sb, "NER", nerMentions, totalMentions);
sb.append("\n");
appendFrac(sb, "Nested NER", nestedNerMentions, totalMentions);
sb.append("\n");
appendIntCountStats(sb, "NER Mention Lengths", nerMentionTokenLengthCounter);
return sb.toString();
}
}
/** Reads and dumps output, mainly for debugging. */
public static void main(String[] args) throws IOException {
Properties props = StringUtils.argsToProperties(args);
boolean debug = false;
String filepath = props.getProperty("i");
String outfile = props.getProperty("o");
if (filepath == null || outfile == null) {
usage();
System.exit(-1);
}
PrintWriter fout = new PrintWriter(outfile);
logger.info("Writing to " + outfile);
String ext = props.getProperty("ext");
Options options;
if (ext != null) {
options = new Options(".*" + ext + "$");
} else {
options = new Options();
}
options.annotateTreeCoref = true;
options.annotateTreeNer = true;
CorpusStats corpusStats = new CorpusStats();
CoNLLDocumentReader reader = new CoNLLDocumentReader(filepath, options);
int docCnt = 0;
int sentCnt = 0;
int tokenCnt = 0;
for (CoNLLDocument doc; (doc = reader.getNextDocument()) != null; ) {
corpusStats.process(doc);
docCnt++;
Annotation anno = doc.getAnnotation();
if (debug) System.out.println("Document " + docCnt + ": " + anno.get(CoreAnnotations.DocIDAnnotation.class));
for (CoreMap sentence:anno.get(CoreAnnotations.SentencesAnnotation.class)) {
if (debug) System.out.println("Parse: " + sentence.get(TreeCoreAnnotations.TreeAnnotation.class));
if (debug) System.out.println("Sentence Tokens: " + StringUtils.join(sentence.get(CoreAnnotations.TokensAnnotation.class), ","));
writeTabSep(fout,sentence,doc.corefChainMap);
sentCnt++;
tokenCnt += sentence.get(CoreAnnotations.TokensAnnotation.class).size();
}
if (debug) {
for (CoreMap ner:doc.nerChunks) {
System.out.println("NER Chunk: " + ner);
}
for (String id:doc.corefChainMap.keySet()) {
System.out.println("Coref: " + id + " = " + StringUtils.join(doc.corefChainMap.get(id), ";"));
}
}
}
fout.close();
System.out.println("Total document count: " + docCnt);
System.out.println("Total sentence count: " + sentCnt);
System.out.println("Total token count: " + tokenCnt);
System.out.println(corpusStats);
}
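/** Reads the next CoNLL document, adds dependency trees to its annotation, extracts the gold mentions,
 * and packages everything as an InputDoc; returns null when the corpus is exhausted. */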
@Override
public InputDoc nextDoc() {
CoNLLDocument conllDoc = getNextDocument();
if (conllDoc == null) return null;
Annotation anno = conllDoc.getAnnotation();
// conll doc has constituency tree but doesn't have dependency tree
setDependencyTree(anno);
List<List<Mention>> allGoldMentions = extractGoldMentions(conllDoc);
// store some useful information in docInfo for later
Map<String, String> docInfo = makeDocInfo(conllDoc);
// TODO: need to add some doc info from conllDoc
return new InputDoc(anno, docInfo, allGoldMentions, conllDoc);
}
// store any useful information for later (as features, debug, etc)
private Map<String, String> makeDocInfo(CoNLLDocument conllDoc) {
Map<String, String> docInfo = Generics.newHashMap();
docInfo.put("DOC_ID", conllDoc.documentID);
docInfo.put("DOC_PART", conllDoc.partNo);
docInfo.put("DOC_ID_PART", conllDoc.documentIdPart);
docInfo.put("DOC_FILE", conllDoc.filename);
return docInfo;
}
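// Adds basic and enhanced dependency graphs derived from each sentence's constituency tree
// (using the Chinese grammatical structure and head finder when the language is Chinese).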
private void setDependencyTree(Annotation anno) {
List<CoreMap> sentences = anno.get(SentencesAnnotation.class);
for(CoreMap sentence : sentences) {
Tree tree = sentence.get(TreeAnnotation.class);
if (tree==null) continue;
SemanticGraph deps = null;
SemanticGraph basicDeps = null;
if (options.lang == Locale.CHINESE) {
deps = SemanticGraphFactory.makeFromTree(
new ChineseGrammaticalStructure(tree, Filters.acceptFilter(), chineseHeadFinder),
SemanticGraphFactory.Mode.COLLAPSED,
GrammaticalStructure.Extras.NONE,
null);
basicDeps = SemanticGraphFactory.makeFromTree(
new ChineseGrammaticalStructure(tree, Filters.acceptFilter(), chineseHeadFinder),
SemanticGraphFactory.Mode.BASIC,
GrammaticalStructure.Extras.NONE,
null);
} else {
deps = SemanticGraphFactory.generateEnhancedDependencies(tree);
basicDeps = SemanticGraphFactory.generateUncollapsedDependencies(tree);
}
sentence.set(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class, basicDeps);
sentence.set(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class, deps);
}
}
// extract gold mentions (mention span, mention ID, cluster ID)
public List<List<Mention>> extractGoldMentions(CoNLLDocument conllDoc) {
List<CoreMap> sentences = conllDoc.getAnnotation().get(CoreAnnotations.SentencesAnnotation.class);
List<List<Mention>> allGoldMentions = new ArrayList<>();
CollectionValuedMap<String,CoreMap> corefChainMap = conllDoc.getCorefChainMap();
for (int i = 0; i < sentences.size(); i++) {
allGoldMentions.add(new ArrayList<>());
}
for (String corefIdStr : corefChainMap.keySet()) {
int id = Integer.parseInt(corefIdStr);
for (CoreMap m : corefChainMap.get(corefIdStr)) {
Mention mention = new Mention();
mention.goldCorefClusterID = id;
int sentIndex = m.get(CoreAnnotations.SentenceIndexAnnotation.class);
CoreMap sent = sentences.get(sentIndex);
mention.startIndex = m.get(CoreAnnotations.TokenBeginAnnotation.class) - sent.get(CoreAnnotations.TokenBeginAnnotation.class);
mention.endIndex = m.get(CoreAnnotations.TokenEndAnnotation.class) - sent.get(CoreAnnotations.TokenBeginAnnotation.class);
mention.originalSpan = m.get(CoreAnnotations.TokensAnnotation.class);
allGoldMentions.get(sentIndex).add(mention);
}
}
return allGoldMentions;
}
}