package edu.stanford.nlp.ie;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.sequences.ListeningSequenceModel;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.ling.CoreAnnotations;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
/**
* This class keeps track of all labeled entities and updates
* its list whenever the label at a point gets changed. This allows
* you to not have to regenerate the list every time, which can be quite
* inefficient.
*
* @author Mengqiu Wang
**/
public abstract class EntityCachingAbstractSequencePriorBIO <IN extends CoreMap> implements ListeningSequenceModel {
/** A logger for this class. Declared final: it is assigned once and never replaced. */
private static final Redwood.RedwoodChannels log = Redwood.channels(EntityCachingAbstractSequencePriorBIO.class);
/** The label sequence most recently handed to setInitialSequence or updateSequenceElement. */
protected int[] sequence;
/** Index (in classIndex) of the background symbol passed to the constructor. */
protected final int backgroundSymbol;
/** Size of classIndex at construction time. */
protected final int numClasses;
/** The identity array {0, 1, ..., numClasses-1}; returned for every position by getPossibleValues. */
protected final int[] possibleValues;
/** Maps label ids to raw tag strings; non-background tags are split on "-" into a B/I prefix and a type. */
protected final Index<String> classIndex;
/** Maps entity type names (the part of a raw tag after the "-") to type ids. */
protected final Index<String> tagIndex;
/** The TextAnnotation of every token of the document, in document order. */
private final List<String> wordDoc;
public EntityCachingAbstractSequencePriorBIO(String backgroundSymbol, Index<String> classIndex, Index<String> tagIndex, List<IN> doc) {
this.classIndex = classIndex;
this.tagIndex = tagIndex;
this.backgroundSymbol = classIndex.indexOf(backgroundSymbol);
this.numClasses = classIndex.size();
this.possibleValues = new int[numClasses];
for (int i=0; i<numClasses; i++) {
possibleValues[i] = i;
}
this.wordDoc = new ArrayList<>(doc.size());
for (IN w: doc) {
wordDoc.add(w.get(CoreAnnotations.TextAnnotation.class));
}
}
/**
 * Debug switch: when true, updateSequenceElement logs each label change and entity split.
 * It is never reassigned, so it is declared as a true constant.
 */
private static final boolean VERBOSE = false;
/**
 * Per-position entity cache: entities[i] is the entity whose span covers position i,
 * or null when no cached entity covers it. Allocated by setInitialSequence.
 */
EntityBIO[] entities;
/**
 * Reports how far to the left of a position this model looks.
 *
 * @return {@code Integer.MAX_VALUE}, i.e. the window is unbounded (the model is not Markovian)
 */
@Override
public int leftWindow() {
return Integer.MAX_VALUE; // not Markovian!
}
/**
 * Reports how far to the right of a position this model looks.
 *
 * @return {@code Integer.MAX_VALUE}, i.e. the window is unbounded (the model is not Markovian)
 */
@Override
public int rightWindow() {
return Integer.MAX_VALUE; // not Markovian!
}
/**
 * Returns the label ids allowed at a position.
 *
 * @param position ignored; the same values are allowed everywhere
 * @return the shared internal array {0, ..., numClasses-1} (not a copy)
 */
@Override
public int[] getPossibleValues(int position) {
return possibleValues;
}
/**
 * Scores the label currently assigned at {@code pos}.
 *
 * @param sequence the label sequence
 * @param pos      the position to score
 * @return the entry of {@code scoresOf(sequence, pos)} that belongs to {@code sequence[pos]}
 */
@Override
public double scoreOf(int[] sequence, int pos) {
  // scoresOf temporarily rewrites sequence[pos] but restores it before returning,
  // so the label is read only after the scores have been computed.
  final double[] scoresForAllLabels = scoresOf(sequence, pos);
  final int currentLabel = sequence[pos];
  return scoresForAllLabels[currentLabel];
}
/**
 * Returns the length of the sequence, which is the number of words collected
 * from the document at construction time.
 *
 * @return the length of the sequence
 */
@Override
public int length() {
return wordDoc.size();
}
/**
 * Gets the number of classes in the sequence model.
 *
 * @return the current size of {@code classIndex} (queried on each call rather than
 *         read from the {@code numClasses} field cached by the constructor)
 */
public int getNumClasses() {
return classIndex.size();
}
/**
 * Turns the per-label scores at {@code position} into a distribution.
 * The scores from {@link #scoresOf(int[], int)} are passed through
 * {@code ArrayMath.logNormalize} and then {@code ArrayMath.exp}.
 *
 * @param sequence the label sequence
 * @param position the position whose label distribution is requested
 * @return one value per class, as produced by the two ArrayMath calls
 */
public double[] getConditionalDistribution (int[] sequence, int position) {
  final double[] logScores = scoresOf(sequence, position);
  ArrayMath.logNormalize(logScores); // return value unused, so this presumably normalizes the array in place
  return ArrayMath.exp(logScores);
}
/**
 * Computes the score of the whole sequence under every possible label at {@code position}.
 * Each candidate label is written into {@code sequence[position]} in turn, the entity
 * cache is brought up to date, and the full sequence is scored. The original label is
 * scored last, so both the array and the cache end up in their initial state.
 *
 * @param sequence the label sequence (modified temporarily, restored before returning)
 * @param position the position whose label is varied
 * @return an array indexed by label id holding the sequence score for that label
 */
@Override
public double[] scoresOf (int[] sequence, int position) {
  final double[] scores = new double[numClasses];
  final int originalLabel = sequence[position];
  // Label the entity cache currently reflects at this position.
  int cachedLabel = originalLabel;

  for (int candidate = 0; candidate < numClasses; candidate++) {
    if (candidate == originalLabel) {
      continue; // handled after the loop so that the original state is restored last
    }
    sequence[position] = candidate;
    updateSequenceElement(sequence, position, cachedLabel);
    scores[candidate] = scoreOf(sequence);
    cachedLabel = candidate;
  }

  // Put the original label back and score it.
  sequence[position] = originalLabel;
  updateSequenceElement(sequence, position, cachedLabel);
  scores[originalLabel] = scoreOf(sequence);
  return scores;
}
/**
 * Installs {@code initialSequence} as the current labeling and rebuilds the entity
 * cache from scratch. Only entities that start with a "B-" tag are cached; a run
 * that starts with an "I-" tag is not picked up here (see the TODO below).
 *
 * @param initialSequence the label ids, one per position
 */
@Override
public void setInitialSequence(int[] initialSequence) {
  this.sequence = initialSequence;
  entities = new EntityBIO[initialSequence.length]; // all slots start out null

  for (int pos = 0; pos < initialSequence.length; pos++) {
    final int label = initialSequence[pos];
    if (label == backgroundSymbol) {
      continue;
    }
    final String[] tagParts = classIndex.get(label).split("-");
    //TODO(mengqiu) this needs to be updated, so that initial can be I as well
    if (!tagParts[0].equals("B")) {
      continue;
    }
    final EntityBIO found = extractEntity(initialSequence, pos, tagParts[1]);
    addEntityToEntitiesArray(found);
    // Jump to the last word of the entity; the loop increment moves past it.
    pos += found.words.size() - 1;
  }
}
/**
 * Points every cache slot covered by {@code entity} (from its start position,
 * one slot per word) at that entity.
 */
private void addEntityToEntitiesArray(EntityBIO entity) {
  final int first = entity.startPosition;
  final int pastLast = first + entity.words.size();
  int slot = first;
  while (slot < pastLast) {
    entities[slot] = entity;
    slot++;
  }
}
/**
 * Builds the entity that starts at {@code position}. The word at {@code position}
 * is always included; the entity is then extended over each following position
 * whose raw tag is {@code "I-" + tag}. The entity is <b>not</b> stored in the
 * cache by this method; callers do that themselves.
 *
 * @param sequence the label sequence to read tags from
 * @param position the index of the first word of the entity
 * @param tag      the entity type name (the part of the raw tag after the "-")
 * @return the newly built entity, with its other occurrences already computed
 **/
public EntityBIO extractEntity(int[] sequence, int position, String tag) {
  final EntityBIO result = new EntityBIO();
  result.type = tagIndex.indexOf(tag);
  result.startPosition = position;
  result.words = new ArrayList<>();
  result.words.add(wordDoc.get(position));

  int next = position + 1;
  while (next < sequence.length) {
    final String[] tagParts = classIndex.get(sequence[next]).split("-");
    final boolean continuesEntity = tagParts[0].equals("I") && tagParts[1].equals(tag);
    if (!continuesEntity) {
      break;
    }
    result.words.add(wordDoc.get(next));
    next++;
  }

  result.otherOccurrences = otherOccurrences(result);
  return result;
}
/**
 * Finds every other position in the document at which the word sequence of
 * {@code entity} occurs (as decided by {@link #matches(EntityBIO, int)}).
 * The entity's own start position is never reported.
 *
 * @param entity the entity whose words are searched for
 * @return the start indices of the other occurrences, in increasing order
 */
public int[] otherOccurrences(EntityBIO entity){
  final List<Integer> starts = new ArrayList<>();
  for (int candidate = 0; candidate < wordDoc.size(); candidate++) {
    if (candidate != entity.startPosition && matches(entity, candidate)) {
      starts.add(candidate);
    }
  }
  return toArray(starts);
}
/**
 * Copies a list of boxed integers into a primitive array, preserving order.
 *
 * @param list the values to copy; must not contain null elements
 * @return a new {@code int[]} with the same length and contents as {@code list}
 */
public static int[] toArray(List<Integer> list) {
  final int[] result = new int[list.size()];
  int next = 0;
  for (Integer value : list) {
    result[next] = value;
    next++;
  }
  return result;
}
/**
 * Tests whether the words of {@code entity} occur in the document starting at
 * {@code position}. Words are compared case-insensitively. A match that would
 * run past the end of the document is rejected.
 *
 * @param entity   the entity whose word sequence is looked for
 * @param position the document index at which the match must start
 * @return true if every word of the entity matches the document at the corresponding offset
 */
public boolean matches(EntityBIO entity, int position) {
  final List<String> entityWords = entity.words;
  if (!wordDoc.get(position).equalsIgnoreCase(entityWords.get(0))) {
    return false;
  }
  for (int offset = 1; offset < entityWords.size(); offset++) {
    final int docPos = position + offset;
    if (docPos >= wordDoc.size()) {
      return false; // ran off the end of the document
    }
    if (!wordDoc.get(docPos).equalsIgnoreCase(entityWords.get(offset))) {
      return false;
    }
  }
  return true;
}
/**
 * Incrementally repairs the entity cache after the label at {@code position} has
 * changed from {@code oldVal} to {@code sequence[position]}. Nothing happens when
 * the two labels are equal. Otherwise the work is dispatched on the new label:
 * background, a "B-" tag, or anything else (treated as an "I-" tag).
 *
 * @param sequence the label sequence, already containing the new label at {@code position}
 * @param position the index whose label changed
 * @param oldVal   the label id that was at {@code position} before the change
 * @throws RuntimeException if the old tag was a "B-" tag, the new tag is background,
 *                          and no entity is cached at {@code position}
 */
@Override
public void updateSequenceElement(int[] sequence, int position, int oldVal) {
  this.sequence = sequence;
  if (sequence[position] == oldVal) {
    return;
  }
  if (VERBOSE) log.info("changing position "+position+" from " +classIndex.get(oldVal)+" to "+classIndex.get(sequence[position]));

  if (sequence[position] == backgroundSymbol) { // new tag is O
    labelChangedToBackground(position, oldVal);
    return;
  }
  String[] parts = classIndex.get(sequence[position]).split("-");
  if (parts[0].equals("B")) { // new tag is B
    labelChangedToBegin(sequence, position, oldVal, parts);
  } else { // new tag is I
    labelChangedToInside(sequence, position, oldVal);
  }
}

/** Handles a change to the background label: the entity at position is removed (old B) or cut short (old I). */
private void labelChangedToBackground(int position, int oldVal) {
  String[] oldParts = classIndex.get(oldVal).split("-");
  if (oldParts[0].equals("B")) { // old tag was a B, the entity starting here disappears entirely
    EntityBIO entity = entities[position];
    if (entity == null) {
      throw new RuntimeException("oldTag starts with B, entity at position should not be null");
    }
    clearCacheSlots(position, entity.words.size());
  } else if (entities[position] != null) { // old tag was an I inside an entity: the entity is shortened
    if (VERBOSE) log.info("splitting off prev entity");
    truncateEntityAt(position);
    if (VERBOSE && position > 0) {
      log.info("position:" + position +", entities[position-1] = " + entities[position-1].toString(tagIndex));
    }
  } // otherwise a stray I-xxx became O and no entity is affected
}

/** Handles a change to a "B-" label: whatever covered position is removed or cut short, then a new entity is extracted from position. */
private void labelChangedToBegin(int[] sequence, int position, int oldVal, String[] newParts) {
  if (oldVal != backgroundSymbol) {
    String[] oldParts = classIndex.get(oldVal).split("-");
    if (oldParts[0].equals("B")) { // was a different B-xxx
      EntityBIO oldEntity = entities[position];
      // A single-word entity is simply overwritten by the new one below;
      // a longer one must have its trailing slots cleared first.
      if (oldEntity.words.size() > 1) {
        clearCacheSlots(position, oldEntity.words.size());
      }
    } else if (entities[position] != null) { // was an I inside an entity: break the old entity here
      truncateEntityAt(position);
    }
  }
  // newParts[1] is read only here, after the old entity has been dealt with.
  EntityBIO entity = extractEntity(sequence, position, newParts[1]);
  addEntityToEntitiesArray(entity);
}

/** Handles a change to an "I-" label: the old entity at position is removed or cut short, and/or the preceding entity is re-extracted so it can grow. */
private void labelChangedToInside(int[] sequence, int position, int oldVal) {
  if (oldVal == backgroundSymbol) { // the previous entity may now extend into this position
    reextractEntityEndingBefore(sequence, position);
    return;
  }
  String[] oldParts = classIndex.get(oldVal).split("-");
  if (oldParts[0].equals("B")) { // was a B: drop that entity, then see whether the previous one extends
    EntityBIO oldEntity = entities[position];
    clearCacheSlots(position, oldEntity.words.size());
    if (VERBOSE && position > 0 && entities[position-1] != null) {
      log.info("position:" + position +", entities[position-1] = " + entities[position-1].toString(tagIndex));
    }
    reextractEntityEndingBefore(sequence, position);
  } else if (entities[position] != null) { // was a different I-xxx inside an entity: shorten that entity
    truncateEntityAt(position);
  } else { // was a stray I-xxx: the previous entity, if any, may now extend
    reextractEntityEndingBefore(sequence, position);
  }
}

/** Sets {@code count} consecutive cache slots, starting at {@code from}, to null. */
private void clearCacheSlots(int from, int count) {
  for (int i = 0; i < count; i++) {
    entities[from + i] = null;
  }
}

/**
 * Shortens the cached entity covering {@code position} so that it ends just before
 * {@code position}, recomputes its other occurrences, and clears the cache slots
 * from {@code position} to the entity's old end.
 */
private void truncateEntityAt(int position) {
  EntityBIO oldEntity = entities[position];
  int oldLen = oldEntity.words.size();
  int offset = position - oldEntity.startPosition; // number of words that survive
  List<String> newWords = new ArrayList<>();
  for (int i = 0; i < offset; i++) {
    newWords.add(oldEntity.words.get(i));
  }
  oldEntity.words = newWords;
  oldEntity.otherOccurrences = otherOccurrences(oldEntity);
  // need to clean any remaining entity
  clearCacheSlots(position, oldLen - offset);
}

/**
 * If an entity is cached at {@code position-1}, re-extracts it (same type) from the
 * sequence so that it can absorb the newly compatible label at {@code position}.
 * The start is computed as in the original code, assuming that entity ends at {@code position-1}.
 */
private void reextractEntityEndingBefore(int[] sequence, int position) {
  if (position > 0 && entities[position-1] != null) {
    String oldTag = tagIndex.get(entities[position-1].type);
    int start = position - 1 - entities[position-1].words.size() + 1;
    EntityBIO entity = extractEntity(sequence, start, oldTag);
    addEntityToEntitiesArray(entity);
  }
}
/**
 * Renders the whole sequence, one line per position: index, word, raw tag and,
 * when an entity is cached at that position, the entity description. Fields are
 * tab-separated.
 */
@Override
public String toString() {
  final StringBuilder out = new StringBuilder();
  for (int pos = 0; pos < entities.length; pos++) {
    out.append(pos).append('\t')
       .append(wordDoc.get(pos)).append('\t')
       .append(classIndex.get(sequence[pos]));
    final EntityBIO cached = entities[pos];
    if (cached != null) {
      out.append('\t').append(cached.toString(tagIndex));
    }
    out.append('\n');
  }
  return out.toString();
}
/**
 * Renders the neighbourhood of {@code pos} in the same tab-separated format as
 * {@link #toString()}. The window is the half-open range {@code [pos-3, pos+3)},
 * clipped to the sequence bounds.
 *
 * @param pos the position around which to print
 * @return one line per position in the window
 */
public String toString(int pos) {
  final int from = Math.max(0, pos - 3);
  final int to = Math.min(entities.length, pos + 3);
  final StringBuilder out = new StringBuilder();
  for (int cur = from; cur < to; cur++) {
    out.append(cur).append('\t')
       .append(wordDoc.get(cur)).append('\t')
       .append(classIndex.get(sequence[cur]));
    final EntityBIO cached = entities[cur];
    if (cached != null) {
      out.append('\t').append(cached.toString(tagIndex));
    }
    out.append('\n');
  }
  return out.toString();
}
}
/**
 * A cached entity mention: a contiguous run of words with a type, plus the
 * other places in the document where the same word sequence occurs.
 */
class EntityBIO {
  /** Index of the entity's first word. */
  public int startPosition;
  /** The words of the entity, in order. */
  public List<String> words;
  /** The entity type, as an id to be resolved through a tag index. */
  public int type;
  /**
   * the beginning index of other locations where this sequence of
   * words appears.
   */
  public int[] otherOccurrences;

  /**
   * Describes this entity on a single line.
   *
   * @param tagIndex the index used to turn {@code type} back into a type name
   * @return the quoted words followed by start position, type name and other occurrences
   */
  public String toString(Index<String> tagIndex) {
    return new StringBuilder()
        .append('"')
        .append(StringUtils.join(words, " "))
        .append("\" start: ")
        .append(startPosition)
        .append(" type: ")
        .append(tagIndex.get(type))
        .append(" other_occurrences: ")
        .append(Arrays.toString(otherOccurrences))
        .toString();
  }
}