package de.berlin.hu.uima.cc.banner.trainer;
import banner.tagging.CRFTagger;
import banner.tagging.FeatureSet;
import banner.tagging.TagFormat;
import banner.types.EntityType;
import banner.types.Mention;
import banner.types.Mention.MentionType;
import banner.types.Sentence.OverlapOption;
import de.berlin.hu.banner.featuresets.KlingerLikeFeatureSet;
import de.berlin.hu.banner.util.ConfigUtil;
import de.berlin.hu.uima.util.Util;
import dragon.nlp.tool.Tagger;
import dragon.nlp.tool.lemmatiser.EngLemmatiser;
import org.apache.commons.configuration.ConfigurationException;
import org.apache.commons.configuration.HierarchicalConfiguration;
import org.apache.commons.configuration.XMLConfiguration;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.CASException;
import org.apache.uima.collection.CasConsumer_ImplBase;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import org.apache.uima.resource.ResourceProcessException;
import org.u_compare.shared.semantic.NamedEntity;
import org.u_compare.shared.syntactic.Sentence;
import org.uimafit.util.JCasUtil;
import java.io.File;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
/**
 * UIMA CAS consumer that trains a BANNER named-entity tagger.
 *
 * <p>Each incoming CAS is expected to contain sentence, token and {@link NamedEntity}
 * annotations. Every sentence is converted into a BANNER training sentence whose mentions are
 * taken from the named entities it covers. When the consumer is destroyed, a CRF is trained on
 * all collected sentences and the model is written to the file named by the
 * {@code BannerModelOutputFile} configuration parameter.
 *
 * @author Tim Rocktäschel
 */
public class BannerTrainer extends CasConsumer_ImplBase {
    /** Configuration parameter holding the path the trained model is written to. */
    private static final String BANNER_MODEL_OUTPUT_FILE_PARAM = "BannerModelOutputFile";
    /** Configuration parameter holding the path of the BANNER XML configuration file. */
    private static final String BANNER_CONFIG_FILE_PARAM = "BannerConfigFile";

    private File bannerModelOutputFile;
    private File bannerConfigFile;
    private HierarchicalConfiguration config;
    private FeatureSet featureSet;
    private TagFormat tagFormat;
    /** All training sentences collected so far, in the order they were processed. */
    private Set<banner.types.Sentence> bannerSentences;
    /** Number of CASes processed; used as the BANNER document id. */
    private int documentCounter;
    /** Number of mentions added to the training sentences. */
    private int numberOfEntities;
    private int crfOrder;

    /**
     * Reads the configuration parameters, loads the BANNER XML configuration and builds the
     * feature set used for training.
     *
     * @throws ResourceInitializationException if a required parameter is missing or the BANNER
     *     configuration file cannot be loaded
     */
    @Override
    public void initialize() throws ResourceInitializationException {
        super.initialize();
        // LinkedHashSet preserves insertion order, so the training examples reach the CRF in a
        // deterministic order (a HashSet would iterate in hash order).
        bannerSentences = new LinkedHashSet<banner.types.Sentence>();
        bannerModelOutputFile = getRequiredFileParameter(BANNER_MODEL_OUTPUT_FILE_PARAM);
        bannerConfigFile = getRequiredFileParameter(BANNER_CONFIG_FILE_PARAM);
        try {
            config = new XMLConfiguration(bannerConfigFile);
        } catch (ConfigurationException e) {
            throw new ResourceInitializationException(e);
        }
        tagFormat = ConfigUtil.getTagFormat(config);
        EngLemmatiser lemmatiser = ConfigUtil.getLemmatiser(config);
        Tagger posTagger = ConfigUtil.getPosTagger(config);
        Set<MentionType> mentionTypes = ConfigUtil.getMentionTypes(config);
        OverlapOption sameTypeOverlapOption = ConfigUtil.getSameTypeOverlapOption(config);
        OverlapOption differentTypeOverlapOption = ConfigUtil.getDifferentTypeOverlapOption(config);
        crfOrder = ConfigUtil.getCRFOrder(config);
        // Klinger et al. (2008)-like feature set
        featureSet = new KlingerLikeFeatureSet(tagFormat, lemmatiser, posTagger, null, mentionTypes,
                sameTypeOverlapOption, differentTypeOverlapOption);
        documentCounter = 0;
        numberOfEntities = 0;
    }

    /**
     * Resolves a mandatory configuration parameter to a {@link File}.
     *
     * @throws ResourceInitializationException if the parameter has no value
     */
    private File getRequiredFileParameter(String parameterName)
            throws ResourceInitializationException {
        Object value = getConfigParameterValue(parameterName);
        if (value == null) {
            throw new ResourceInitializationException(new IllegalArgumentException(
                    "Missing required configuration parameter: " + parameterName));
        }
        return new File(value.toString());
    }

    /**
     * Converts every sentence of the CAS into a BANNER training sentence (tokens plus mentions)
     * and stores it for training.
     *
     * @throws ResourceProcessException if the JCas view of the CAS cannot be obtained
     */
    @Override
    public void processCas(CAS aCas) throws ResourceProcessException {
        JCas aJCas;
        try {
            aJCas = aCas.getJCas();
        } catch (CASException e) {
            throw new ResourceProcessException(e);
        }
        Iterator<Sentence> sentenceIterator = JCasUtil.iterator(aJCas, Sentence.class);
        int sentenceCounter = 0;
        while (sentenceIterator.hasNext()) {
            Sentence sentence = sentenceIterator.next();
            banner.types.Sentence bannerSentence = new banner.types.Sentence(
                    Integer.toString(sentenceCounter), Integer.toString(documentCounter),
                    sentence.getCoveredText());
            // all tokens covering the current sentence, in document offsets
            List<org.u_compare.shared.syntactic.Token> tokensInSentence =
                    Util.getTokens(aJCas, sentence.getBegin(), sentence.getEnd());
            Util.tokenizeBannerSentence(bannerSentence, tokensInSentence);
            assert tokensInSentence.size() == bannerSentence.getTokens().size();
            addMentions(sentence, bannerSentence, tokensInSentence);
            bannerSentences.add(bannerSentence);
            sentenceCounter++;
        }
        documentCounter++;
    }

    /**
     * Adds one BANNER mention to {@code bannerSentence} for every named entity inside
     * {@code sentence}. An entity that lies completely inside the previously accepted entity is
     * reported as a probable annotation error and skipped.
     */
    private void addMentions(Sentence sentence, banner.types.Sentence bannerSentence,
            List<org.u_compare.shared.syntactic.Token> tokensInSentence) {
        int sentenceEnd = sentence.getEnd();
        // ambiguous = true, strict = true: presumably every named entity fully covered by the
        // sentence, overlapping ones included (uimaFIT semantics) - TODO confirm
        Iterator<NamedEntity> entityIterator =
                JCasUtil.iterator(sentence, NamedEntity.class, true, true);
        // Only accepted entities are remembered; comparing against a skipped (nested) entity
        // would let further entities nested in the accepted one slip through.
        NamedEntity lastAcceptedEntity = null;
        while (entityIterator.hasNext()) {
            NamedEntity currentEntity = entityIterator.next();
            if (overlaps(lastAcceptedEntity, currentEntity)) {
                System.out.println("Probable annotation error: " + lastAcceptedEntity.getCoveredText()
                        + " overlaps " + currentEntity.getCoveredText());
                continue;
            }
            int entityBegin = currentEntity.getBegin();
            int entityEnd = currentEntity.getEnd();
            if (entityBegin >= sentenceEnd || entityEnd > sentenceEnd) {
                // safeguard: entity is not within the sentence, stop converting this sentence
                break;
            }
            int tokenPositionBegin = getTokenPositionBegin(entityBegin, tokensInSentence);
            int tokenPositionEnd = getTokenPositionEnd(entityEnd, tokensInSentence);
            // +1: presumably Mention takes an exclusive end token index - TODO confirm
            Mention mention = new Mention(bannerSentence, tokenPositionBegin, tokenPositionEnd + 1,
                    EntityType.getType(currentEntity.getEntityType()), MentionType.Required);
            bannerSentence.addMention(mention);
            numberOfEntities++;
            lastAcceptedEntity = currentEntity;
        }
    }

    /**
     * Returns whether {@code currentEntity} lies completely inside {@code lastEntity} (nested or
     * identical span). Partial overlaps are not reported here and are passed on to BANNER.
     *
     * @param lastEntity the previously accepted entity, or {@code null} if there is none
     * @param currentEntity the entity to check
     */
    private boolean overlaps(NamedEntity lastEntity, NamedEntity currentEntity) {
        return lastEntity != null
                && currentEntity.getBegin() >= lastEntity.getBegin()
                && currentEntity.getEnd() <= lastEntity.getEnd();
    }

    /**
     * @return the index of the token that contains the entity's begin offset
     * @throws IllegalArgumentException if no token contains {@code currentEntityBegin}
     */
    private int getTokenPositionBegin(int currentEntityBegin,
            List<org.u_compare.shared.syntactic.Token> tokensInSentence) {
        for (int i = 0; i < tokensInSentence.size(); i++) {
            org.u_compare.shared.syntactic.Token token = tokensInSentence.get(i);
            if (token.getBegin() <= currentEntityBegin && currentEntityBegin < token.getEnd()) {
                return i;
            }
        }
        throw new IllegalArgumentException("No token contains entity begin offset "
                + currentEntityBegin + "; token spans: " + describeTokenSpans(tokensInSentence));
    }

    /**
     * @return the index of the token that contains the entity's (exclusive) end offset
     * @throws IllegalArgumentException if no token ends at or contains {@code currentNamedEnd}
     */
    private int getTokenPositionEnd(int currentNamedEnd,
            List<org.u_compare.shared.syntactic.Token> tokensInSentence) {
        for (int i = 0; i < tokensInSentence.size(); i++) {
            org.u_compare.shared.syntactic.Token token = tokensInSentence.get(i);
            if (token.getBegin() < currentNamedEnd && currentNamedEnd <= token.getEnd()) {
                return i;
            }
        }
        throw new IllegalArgumentException("No token contains entity end offset "
                + currentNamedEnd + "; token spans: " + describeTokenSpans(tokensInSentence));
    }

    /** Formats the token offsets as {@code [begin,end), ...} for error messages. */
    private static String describeTokenSpans(List<org.u_compare.shared.syntactic.Token> tokens) {
        StringBuilder spans = new StringBuilder();
        for (org.u_compare.shared.syntactic.Token token : tokens) {
            if (spans.length() > 0) {
                spans.append(", ");
            }
            spans.append('[').append(token.getBegin()).append(',').append(token.getEnd()).append(')');
        }
        return spans.toString();
    }

    // NOTE(review): training is triggered from destroy(). The original author flagged that this
    // is not guaranteed to run only after all CASes have passed through; the CasConsumer
    // collectionProcessComplete(ProcessTrace) callback may be the better hook - TODO confirm.
    @Override
    public void destroy() {
        try {
            trainAndWriteModel();
        } finally {
            super.destroy();
        }
    }

    /**
     * Trains a CRF on the collected sentences and writes the model to
     * {@code bannerModelOutputFile}. Exceptions are reported on stderr rather than rethrown,
     * because {@link #destroy()} cannot throw checked exceptions.
     */
    private void trainAndWriteModel() {
        if (bannerSentences == null || bannerSentences.isEmpty()) {
            System.out.println("No training sentences collected, skipping training");
            return;
        }
        System.out.println("Number of training sentences: " + bannerSentences.size());
        System.out.println("Number of entities: " + numberOfEntities);
        System.out.println("Training data loaded, starting training");
        try {
            CRFTagger tagger = CRFTagger.train(bannerSentences, crfOrder, tagFormat, featureSet);
            System.out.println("Training complete, saving model");
            // relative path, presumably resolved against the working directory
            tagger.describe("model_describe.txt");
            tagger.write(bannerModelOutputFile);
        } catch (Exception e) {
            // Errors (e.g. OutOfMemoryError) are deliberately not caught any more.
            System.err.println("BANNER training failed, no model written to " + bannerModelOutputFile);
            e.printStackTrace();
        }
    }
}