package edu.stanford.nlp.coref.neural;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;

import org.ejml.simple.SimpleMatrix;

import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;

/**
 * Extracts word-embedding features from mentions.
 *
 * <p>Two embedding tables are used: a "static" table, which backs every averaged
 * feature, and a "tuned" table, which is consulted first for single-word lookups
 * and falls back to the static table for words it does not contain.
 *
 * @author Kevin Clark
 */
public class EmbeddingExtractor {

  /** Matches a single digit; every match is rewritten to {@code 0} during normalization. */
  private static final Pattern DIGIT = Pattern.compile("\\d");

  /** Number of words averaged on each side of a mention for the context features. */
  private static final int CONTEXT_WINDOW = 5;

  private final boolean conll;
  private final Embedding staticWordEmbeddings;
  private final Embedding tunedWordEmbeddings;

  /**
   * @param conll whether a real document embedding is computed; when {@code false},
   *     {@link #getDocumentEmbedding(Document)} returns a freshly allocated matrix
   *     without reading the document
   * @param staticWordEmbeddings table used for averaged features and as lookup fallback
   * @param tunedWordEmbeddings table consulted first for single-word lookups
   */
  public EmbeddingExtractor(boolean conll, Embedding staticWordEmbeddings,
      Embedding tunedWordEmbeddings) {
    this.conll = conll;
    this.staticWordEmbeddings = staticWordEmbeddings;
    this.tunedWordEmbeddings = tunedWordEmbeddings;
  }

  /**
   * Builds the document-level embedding feature.
   *
   * <p>When {@code conll} is set, this is the average static embedding over the words
   * of every sentence that holds at least one predicted mention; each sentence
   * (identified by {@code sentNum}) contributes its words once, however many mentions
   * it contains. Otherwise a new column matrix of the static embedding size is
   * returned without looking at the document.
   *
   * @param document the document whose predicted mentions are scanned
   * @return a column matrix with {@code staticWordEmbeddings.getEmbeddingSize()} rows
   */
  public SimpleMatrix getDocumentEmbedding(Document document) {
    if (!conll) {
      return new SimpleMatrix(staticWordEmbeddings.getEmbeddingSize(), 1);
    }

    List<CoreLabel> words = new ArrayList<>();
    Set<Integer> seenSentences = new HashSet<>();
    for (Mention m : document.predictedMentionsByID.values()) {
      // Set.add returns false when the sentence was already recorded.
      if (seenSentences.add(m.sentNum)) {
        words.addAll(m.sentenceWords);
      }
    }
    return getAverageEmbedding(words);
  }

  /**
   * Concatenates all embedding features for one mention, in this order:
   * <ol>
   *   <li>average over the mention span {@code [startIndex, endIndex)};</li>
   *   <li>average over the five-word window before the mention;</li>
   *   <li>average over the five-word window after the mention;</li>
   *   <li>average over the sentence without its last token;</li>
   *   <li>the supplied document embedding;</li>
   *   <li>word embeddings at {@code headIndex}, {@code startIndex}, {@code endIndex - 1},
   *       {@code startIndex - 1}, {@code endIndex}, {@code startIndex - 2} and
   *       {@code endIndex + 1};</li>
   *   <li>the word embedding of the source of the first incoming dependency edge of the
   *       head word, or of the missing-word token when the head has no incoming edge.</li>
   * </ol>
   *
   * <p>Positions outside the sentence resolve to the missing-word token.
   *
   * <p>NOTE(review): the ranged averages clamp their exclusive end to {@code size - 1},
   * and the sentence average uses {@code subList(0, size - 1)}, so the final token of
   * the sentence never contributes to an average. This is kept exactly as found because
   * it determines the feature values; presumably existing models were trained with it.
   * Confirm before changing.
   *
   * @param m the mention to featurize
   * @param docEmbedding the document-level feature, inserted unchanged
   * @return the concatenation produced by {@link NeuralUtils#concatenate}
   * @throws IndexOutOfBoundsException if {@code m.sentenceWords} is empty
   */
  public SimpleMatrix getMentionEmbeddings(Mention m, SimpleMatrix docEmbedding) {
    Iterator<SemanticGraphEdge> depIterator =
        m.enhancedDependency.incomingEdgeIterator(m.headIndexedWord);
    SemanticGraphEdge depRelation = depIterator.hasNext() ? depIterator.next() : null;
    String governorWord = depRelation == null ? null : depRelation.getSource().word();

    List<CoreLabel> sentence = m.sentenceWords;
    int start = m.startIndex;
    int end = m.endIndex;

    return NeuralUtils.concatenate(
        getAverageEmbedding(sentence, start, end),
        getAverageEmbedding(sentence, start - CONTEXT_WINDOW, start),
        getAverageEmbedding(sentence, end, end + CONTEXT_WINDOW),
        getAverageEmbedding(sentence.subList(0, sentence.size() - 1)),
        docEmbedding,
        getWordEmbedding(sentence, m.headIndex),
        getWordEmbedding(sentence, start),
        getWordEmbedding(sentence, end - 1),
        getWordEmbedding(sentence, start - 1),
        getWordEmbedding(sentence, end),
        getWordEmbedding(sentence, start - 2),
        getWordEmbedding(sentence, end + 1),
        getWordEmbedding(governorWord)
    );
  }

  /**
   * Averages the static embeddings of the given words. The divisor is at least 1, so
   * an empty list yields the initial accumulator divided by 1 rather than a division
   * by zero.
   */
  private SimpleMatrix getAverageEmbedding(List<CoreLabel> words) {
    SimpleMatrix emb = new SimpleMatrix(staticWordEmbeddings.getEmbeddingSize(), 1);
    for (CoreLabel word : words) {
      emb = emb.plus(getStaticWordEmbedding(word.word()));
    }
    return emb.divide(Math.max(1, words.size()));
  }

  /**
   * Averages the static embeddings of {@code sentence[start, end)} after clamping both
   * bounds into {@code [0, size - 1]}. Because the exclusive end is clamped to
   * {@code size - 1}, the last token of the sentence is never part of the range.
   */
  private SimpleMatrix getAverageEmbedding(List<CoreLabel> sentence, int start, int end) {
    int size = sentence.size();
    return getAverageEmbedding(sentence.subList(clampIndex(start, size), clampIndex(end, size)));
  }

  /** Clamps {@code index} into {@code [0, size - 1]}; yields 0 when {@code size} is 0. */
  private static int clampIndex(int index, int size) {
    return Math.max(Math.min(index, size - 1), 0);
  }

  /** Looks up the word at position {@code i}; out-of-range positions count as missing. */
  private SimpleMatrix getWordEmbedding(List<CoreLabel> sentence, int i) {
    boolean outOfRange = i < 0 || i >= sentence.size();
    return getWordEmbedding(outOfRange ? null : sentence.get(i).word());
  }

  /**
   * Returns the embedding for a word after normalization, taken from the tuned table
   * when it contains the normalized word and from the static table otherwise.
   *
   * @param word the surface form; {@code null} is looked up as {@code "<missing>"}
   * @return the result of {@code get} on whichever table was selected
   */
  public SimpleMatrix getWordEmbedding(String word) {
    word = normalizeWord(word);
    return tunedWordEmbeddings.containsWord(word)
        ? tunedWordEmbeddings.get(word)
        : staticWordEmbeddings.get(word);
  }

  /**
   * Returns the static-table embedding for a word after normalization.
   *
   * @param word the surface form; {@code null} is looked up as {@code "<missing>"}
   * @return the result of {@code staticWordEmbeddings.get} on the normalized word
   */
  public SimpleMatrix getStaticWordEmbedding(String word) {
    return staticWordEmbeddings.get(normalizeWord(word));
  }

  /**
   * Maps a token to its lookup key: {@code null} becomes {@code "<missing>"}, escaped
   * punctuation and bracket tokens become the literal character, and everything else
   * has each digit replaced by {@code 0} and is lower-cased.
   *
   * <p>Lower-casing uses {@link Locale#ROOT} so keys do not depend on the JVM default
   * locale (e.g. under a Turkish locale {@code "I"} would otherwise become a dotless i
   * and miss the vocabulary entry).
   */
  private static String normalizeWord(String w) {
    if (w == null) {
      return "<missing>";
    }
    switch (w) {
      case "/.":
        return ".";
      case "/?":
        return "?";
      case "-LRB-":
        return "(";
      case "-RRB-":
        return ")";
      case "-LCB-":
        return "{";
      case "-RCB-":
        return "}";
      case "-LSB-":
        return "[";
      case "-RSB-":
        return "]";
      default:
        return DIGIT.matcher(w).replaceAll("0").toLowerCase(Locale.ROOT);
    }
  }
}