package edu.stanford.nlp.ie;
import edu.stanford.nlp.sequences.ListeningSequenceModel;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.ling.CoreAnnotations;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
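/**
 * A sequence prior that caches the labeled entities of the current sequence.
 * An entity is a maximal run of consecutive positions sharing one
 * non-background label. The cache is updated incrementally whenever the label
 * at a single position changes. This avoids regenerating the whole entity list
 * after every change, which can be quite inefficient.
 *
 * @author Jenny Finkel
 **/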
public abstract class EntityCachingAbstractSequencePrior<IN extends CoreMap> implements ListeningSequenceModel {

  /** The labeling the entity cache currently describes (set on init and on every update). */
  protected int[] sequence;
  /** Label id of the background (non-entity) symbol, as returned by {@code classIndex.indexOf}. */
  protected final int backgroundSymbol;
  /** Number of labels in {@link #classIndex} at construction time. */
  protected final int numClasses;
  /** The label ids {@code 0 .. numClasses-1}; the same array is returned for every position. */
  protected final int[] possibleValues;
  /** Maps label names to label ids. */
  protected final Index<String> classIndex;
  /** The tokens being labeled; entry {@code i} of a sequence labels {@code doc.get(i)}. */
  protected final List<IN> doc;

  /** Compile-time switch for tracing cache updates to stdout. */
  private static final boolean VERBOSE = false;

  /**
   * The entity cache: {@code entities[i]} is the entity covering position {@code i},
   * or {@code null} where the label is the background symbol. All positions of a
   * multi-token entity point at the same {@link Entity} object.
   */
  Entity[] entities;

  /**
   * @param backgroundSymbol name of the background label, looked up in {@code classIndex}
   * @param classIndex index mapping label names to label ids
   * @param doc the tokens whose labels are being scored
   */
  public EntityCachingAbstractSequencePrior(String backgroundSymbol, Index<String> classIndex, List<IN> doc) {
    this.classIndex = classIndex;
    this.backgroundSymbol = classIndex.indexOf(backgroundSymbol);
    this.numClasses = classIndex.size();
    this.possibleValues = new int[numClasses];
    for (int i = 0; i < numClasses; i++) {
      possibleValues[i] = i;
    }
    this.doc = doc;
  }

  @Override
  public int leftWindow() {
    return Integer.MAX_VALUE; // not Markovian: the score depends on the whole sequence
  }

  @Override
  public int rightWindow() {
    return Integer.MAX_VALUE; // not Markovian: the score depends on the whole sequence
  }

  @Override
  public int[] getPossibleValues(int position) {
    return possibleValues;
  }

  /** Returns the score of the label currently at {@code pos}, taken from {@link #scoresOf}. */
  @Override
  public double scoreOf(int[] sequence, int pos) {
    return scoresOf(sequence, pos)[sequence[pos]];
  }

  /**
   * @return the length of the sequence (the number of tokens in the document)
   */
  @Override
  public int length() {
    return doc.size();
  }

  /**
   * Gets the number of classes in the sequence model.
   */
  public int getNumClasses() {
    return classIndex.size();
  }

  /**
   * Converts the scores from {@link #scoresOf} into a distribution over labels at
   * {@code position}. The scores are log-normalized with {@code ArrayMath.logNormalize}
   * and then exponentiated.
   */
  public double[] getConditionalDistribution (int[] sequence, int position) {
    double[] probs = scoresOf(sequence, position);
    ArrayMath.logNormalize(probs);
    return ArrayMath.exp(probs);
  }

  /**
   * Scores every label at {@code position}. Each alternative label is written into
   * {@code sequence}, the entity cache is updated, and {@code scoreOf(int[])} is called.
   * On return, both the array and the entity cache are restored to the original label,
   * so later incremental updates start from a consistent state.
   *
   * @return an array of length {@code numClasses}; entry {@code k} is the score with label {@code k}
   */
  @Override
  public double[] scoresOf (int[] sequence, int position) {
    double[] probs = new double[numClasses];
    int origClass = sequence[position];
    int current = origClass; // label the cache currently reflects at position
    for (int label = 0; label < numClasses; label++) {
      if (label == origClass) {
        continue; // scored below, after the cache is restored
      }
      sequence[position] = label;
      updateSequenceElement(sequence, position, current);
      probs[label] = scoreOf(sequence);
      current = label;
    }
    sequence[position] = origClass;
    if (current != origClass) {
      // Without this, the cache would keep describing the last label tried.
      updateSequenceElement(sequence, position, current);
    }
    probs[origClass] = scoreOf(sequence);
    return probs;
  }

  /**
   * Resets the cache to describe {@code initialSequence}, building one entity per
   * maximal run of equal non-background labels.
   */
  @Override
  public void setInitialSequence(int[] initialSequence) {
    this.sequence = initialSequence;
    entities = new Entity[initialSequence.length]; // all null (background) until filled in
    int i = 0;
    while (i < initialSequence.length) {
      if (initialSequence[i] == backgroundSymbol) {
        i++;
      } else {
        Entity entity = extractEntity(initialSequence, i);
        addEntityToEntitiesArray(entity);
        i += entity.words.size(); // skip the rest of this entity
      }
    }
  }

  /** Points every cache slot covered by {@code entity} at {@code entity}. */
  private void addEntityToEntitiesArray(Entity entity) {
    int end = entity.startPosition + entity.words.size();
    for (int j = entity.startPosition; j < end; j++) {
      entities[j] = entity;
    }
  }

  /**
   * Builds the entity starting at {@code position}. The entity spans every
   * consecutive position with the same label as {@code sequence[position]}, and its
   * other occurrences in the document are computed.
   * The entity cache is <b>not</b> modified.
   *
   * @return the new entity
   */
  public Entity extractEntity(int[] sequence, int position) {
    Entity entity = new Entity();
    entity.type = sequence[position];
    entity.startPosition = position;
    entity.words = new ArrayList<>();
    for (int i = position; i < sequence.length && sequence[i] == entity.type; i++) {
      entity.words.add(wordAt(i));
    }
    entity.otherOccurrences = otherOccurrences(entity);
    return entity;
  }

  /**
   * Finds the other start positions in the document where the word sequence of
   * this entity occurs (case-insensitively). The entity's own start is excluded.
   */
  public int[] otherOccurrences(Entity entity){
    List<Integer> other = new ArrayList<>();
    for (int i = 0; i < doc.size(); i++) {
      if (i != entity.startPosition && matches(entity, i)) {
        other.add(i);
      }
    }
    return toArray(other);
  }

  /** Unboxes {@code list} into a new {@code int[]}. */
  public static int[] toArray(List<Integer> list) {
    int[] arr = new int[list.size()];
    for (int i = 0; i < arr.length; i++) {
      arr[i] = list.get(i);
    }
    return arr;
  }

  /**
   * Returns true if the entity's words appear, case-insensitively, at
   * {@code position} and the following positions of the document. Returns false
   * if they would run past the end of the document.
   */
  public boolean matches(Entity entity, int position) {
    if (!wordAt(position).equalsIgnoreCase(entity.words.get(0))) {
      return false;
    }
    for (int j = 1; j < entity.words.size(); j++) {
      if (position + j >= doc.size()) {
        return false;
      }
      if (!wordAt(position + j).equalsIgnoreCase(entity.words.get(j))) {
        return false;
      }
    }
    return true;
  }

  /** True if the new non-background label at {@code position} equals both neighbours' labels. */
  public boolean joiningTwoEntities(int[] sequence, int position) {
    if (sequence[position] == backgroundSymbol) { return false; }
    if (position > 0 && position < sequence.length - 1) {
      return (sequence[position] == sequence[position - 1] &&
              sequence[position] == sequence[position + 1]);
    }
    return false;
  }

  /** True if the cached entity on both sides of {@code position} is one and the same (non-null) entity. */
  public boolean splittingTwoEntities(int[] sequence, int position) {
    if (position > 0 && position < sequence.length - 1) {
      return (entities[position - 1] == entities[position + 1] &&
              entities[position - 1] != null);
    }
    return false;
  }

  /** True if {@code position} now extends the cached entity that ends exactly at {@code position - 1}. */
  public boolean appendingEntity(int[] sequence, int position) {
    if (position > 0) {
      Entity prev = entities[position - 1];
      if (prev == null) { return false; }
      return (sequence[position] == sequence[position - 1] &&
              prev.startPosition + prev.words.size() == position);
    }
    return false;
  }

  /** True if {@code position} now has the same label as {@code position + 1}, which has a cached entity. */
  public boolean prependingEntity(int[] sequence, int position) {
    if (position < sequence.length - 1) {
      if (entities[position + 1] == null) { return false; }
      return (sequence[position] == sequence[position + 1]);
    }
    return false;
  }

  /** True if the non-background label at {@code position} differs from both neighbours' labels. */
  public boolean addingSingletonEntity(int[] sequence, int position) {
    if (sequence[position] == backgroundSymbol) { return false; }
    if (position > 0 && sequence[position - 1] == sequence[position]) { return false; }
    if (position < sequence.length - 1 && sequence[position + 1] == sequence[position]) { return false; }
    return true;
  }

  /** True if the cached entity at {@code position - 1} still extends over {@code position}. */
  public boolean removingEndOfEntity(int[] sequence, int position) {
    if (position > 0) {
      if (sequence[position - 1] == backgroundSymbol) { return false; }
      Entity prev = entities[position - 1];
      if (prev != null) {
        return (prev.startPosition + prev.words.size() > position);
      }
    }
    return false;
  }

  /** True if the cached entity at {@code position + 1} still starts at or before {@code position}. */
  public boolean removingBeginningOfEntity(int[] sequence, int position) {
    if (position < sequence.length - 1) {
      if (sequence[position + 1] == backgroundSymbol) { return false; }
      Entity next = entities[position + 1];
      if (next != null) {
        return (next.startPosition <= position);
      }
    }
    return false;
  }

  /**
   * True if a neighbour with the same label is already cached as part of the same
   * entity as {@code position}. Returns false when this cannot be decided (e.g. a
   * singleton whose type may have changed).
   */
  public boolean noChange(int[] sequence, int position) {
    if (position > 0 && sequence[position - 1] == sequence[position]) {
      return entities[position - 1] == entities[position];
    }
    if (position < sequence.length - 1 && sequence[position + 1] == sequence[position]) {
      return entities[position] == entities[position + 1];
    }
    // actually, can't tell. either no change, or singleton changed type
    return false;
  }

  /**
   * Incrementally repairs the entity cache after the label at {@code position} has
   * been changed in {@code sequence}; the new label must already be in the array.
   * Cases are tried in order and the first match wins: no change, joining two
   * entities, splitting one, prepending to the next, appending to the previous,
   * adding a singleton, trimming the end of the previous entity, and trimming the
   * start of the next one. Otherwise the slot at {@code position} is cleared.
   *
   * @param sequence the full labeling, already holding the new label at {@code position}
   * @param position the index whose label changed
   * @param oldVal the label previously at {@code position}; used here only for verbose tracing
   */
  @Override
  public void updateSequenceElement(int[] sequence, int position, int oldVal) {
    if (VERBOSE) System.out.println("changing position "+position+" from " +classIndex.get(oldVal)+" to "+classIndex.get(sequence[position]));
    this.sequence = sequence;

    if (noChange(sequence, position)) {
      if (VERBOSE) System.out.println("no change");
    } else if (joiningTwoEntities(sequence, position)) {
      if (VERBOSE) System.out.println("joining 2 entities");
      joinAround(sequence, position);
    } else if (splittingTwoEntities(sequence, position)) {
      if (VERBOSE) System.out.println("splitting into 2 entities");
      splitAround(sequence, position);
    } else if (prependingEntity(sequence, position)) {
      if (VERBOSE) System.out.println("prepending entity");
      prependToNext(sequence, position);
    } else if (appendingEntity(sequence, position)) {
      if (VERBOSE) System.out.println("appending entity");
      appendToPrevious(sequence, position);
    } else if (addingSingletonEntity(sequence, position)) {
      if (VERBOSE) System.out.println("adding singleton entity");
      addEntityToEntitiesArray(singletonAt(sequence, position));
      trimPreviousIfNeeded(sequence, position);
      trimNextIfNeeded(sequence, position);
    } else if (removingEndOfEntity(sequence, position)) {
      if (VERBOSE) System.out.println("splitting off prev entity");
      dropLastWord(entities[position - 1]);
      entities[position] = null;
    } else if (removingBeginningOfEntity(sequence, position)) {
      if (VERBOSE) System.out.println("splitting off next entity");
      dropFirstWord(entities[position + 1]);
      entities[position] = null;
    } else {
      entities[position] = null;
    }
    if (VERBOSE) System.out.println(this);
  }

  /** Merges the entities at {@code position - 1} and {@code position + 1} through {@code position}. */
  private void joinAround(int[] sequence, int position) {
    Entity prev = entities[position - 1];
    Entity next = entities[position + 1];
    Entity joined = new Entity();
    joined.type = sequence[position];
    joined.startPosition = prev.startPosition;
    joined.words = new ArrayList<>(prev.words);
    joined.words.add(wordAt(position));
    joined.words.addAll(next.words);
    // The joined words extend prev's words, so only prev's occurrences can still match.
    joined.otherOccurrences = occurrencesAmong(joined, prev.otherOccurrences);
    addEntityToEntitiesArray(joined);
  }

  /** Cuts the entity covering {@code position} into a left and a right part, relabelling {@code position} itself. */
  private void splitAround(int[] sequence, int position) {
    Entity entity = entities[position];
    int cut = position - entity.startPosition;
    List<String> left = new ArrayList<>(entity.words.subList(0, cut));
    List<String> right = new ArrayList<>(entity.words.subList(cut + 1, entity.words.size()));
    addEntityToEntitiesArray(newEntity(entity.type, entity.startPosition, left));
    addEntityToEntitiesArray(newEntity(entity.type, position + 1, right));
    entities[position] = (sequence[position] == backgroundSymbol) ? null : singletonAt(sequence, position);
  }

  /** Makes {@code position} the first token of the entity at {@code position + 1}, trimming the previous entity if needed. */
  private void prependToNext(int[] sequence, int position) {
    Entity next = entities[position + 1];
    List<String> words = new ArrayList<>(next.words.size() + 1);
    words.add(wordAt(position));
    words.addAll(next.words);
    addEntityToEntitiesArray(newEntity(sequence[position], position, words));
    trimPreviousIfNeeded(sequence, position);
  }

  /** Makes {@code position} the last token of the entity at {@code position - 1}, trimming the next entity if needed. */
  private void appendToPrevious(int[] sequence, int position) {
    Entity prev = entities[position - 1];
    Entity appended = new Entity();
    appended.type = sequence[position];
    appended.startPosition = prev.startPosition;
    appended.words = new ArrayList<>(prev.words);
    appended.words.add(wordAt(position));
    // The appended words extend prev's words, so only prev's occurrences can still match.
    appended.otherOccurrences = occurrencesAmong(appended, prev.otherOccurrences);
    addEntityToEntitiesArray(appended);
    trimNextIfNeeded(sequence, position);
  }

  /** Removes {@code position} from the end of the previous entity if that entity still covers it. */
  private void trimPreviousIfNeeded(int[] sequence, int position) {
    if (removingEndOfEntity(sequence, position)) {
      if (VERBOSE) System.out.println(" ... and removing end of previous entity.");
      dropLastWord(entities[position - 1]);
    }
  }

  /** Removes {@code position} from the start of the next entity if that entity still covers it. */
  private void trimNextIfNeeded(int[] sequence, int position) {
    if (removingBeginningOfEntity(sequence, position)) {
      if (VERBOSE) System.out.println(" ... and removing beginning of next entity.");
      dropFirstWord(entities[position + 1]);
    }
  }

  /** Shortens {@code entity} by its last token and recomputes its other occurrences. */
  private void dropLastWord(Entity entity) {
    entity.words.remove(entity.words.size() - 1);
    entity.otherOccurrences = otherOccurrences(entity);
  }

  /** Shortens {@code entity} by its first token and recomputes its other occurrences. */
  private void dropFirstWord(Entity entity) {
    entity.words.remove(0);
    entity.startPosition++;
    entity.otherOccurrences = otherOccurrences(entity);
  }

  /** Builds a one-token entity at {@code position} with the label currently there. */
  private Entity singletonAt(int[] sequence, int position) {
    List<String> words = new ArrayList<>(1);
    words.add(wordAt(position));
    return newEntity(sequence[position], position, words);
  }

  /** Builds an entity that takes ownership of {@code words} and computes its other occurrences. */
  private Entity newEntity(int type, int startPosition, List<String> words) {
    Entity entity = new Entity();
    entity.type = type;
    entity.startPosition = startPosition;
    entity.words = words;
    entity.otherOccurrences = otherOccurrences(entity);
    return entity;
  }

  /** Returns the subset of {@code candidates} at which {@code entity} matches. */
  private int[] occurrencesAmong(Entity entity, int[] candidates) {
    List<Integer> kept = new ArrayList<>();
    for (int pos : candidates) {
      if (matches(entity, pos)) {
        kept.add(pos);
      }
    }
    return toArray(kept);
  }

  /** Text of the token at {@code position}. */
  private String wordAt(int position) {
    return doc.get(position).get(CoreAnnotations.TextAnnotation.class);
  }

  /** One line per position: index, word, label and, if present, the cached entity. */
  @Override
  public String toString() {
    return describePositions(0, entities.length);
  }

  /** Like {@link #toString()}, restricted to positions {@code pos-10 .. pos+9}, clipped to the sequence. */
  public String toString(int pos) {
    return describePositions(Math.max(0, pos - 10), Math.min(entities.length, pos + 10));
  }

  /** Renders positions {@code from} (inclusive) to {@code to} (exclusive) as tab-separated lines. */
  private String describePositions(int from, int to) {
    StringBuilder sb = new StringBuilder();
    for (int i = from; i < to; i++) {
      sb.append(i).append('\t').append(wordAt(i)).append('\t').append(classIndex.get(sequence[i]));
      if (entities[i] != null) {
        sb.append('\t').append(entities[i].toString(classIndex));
      }
      sb.append('\n');
    }
    return sb.toString();
  }
}
/**
 * One cached entity: a run of consecutive tokens that share a single label.
 * Fields are mutated in place by EntityCachingAbstractSequencePrior.
 */
class Entity {
  /** Document index of the entity's first token. */
  public int startPosition;
  /** Surface text of the entity's tokens, in document order. */
  public List<String> words;
  /** Label id shared by all tokens of the entity. */
  public int type;
  /**
   * the beginning index of other locations where this sequence of
   * words appears.
   */
  public int[] otherOccurrences;

  /**
   * Debug rendering of the form
   * {@code "w1 w2" start: S type: LABEL other_occurrences: [..]}.
   *
   * @param classIndex used to turn {@link #type} back into its label name
   */
  public String toString(Index<String> classIndex) {
    String text = StringUtils.join(words, " ");
    String label = classIndex.get(type);
    String others = Arrays.toString(otherOccurrences);
    return "\"" + text + "\" start: " + startPosition
        + " type: " + label
        + " other_occurrences: " + others;
  }
}