/*
* Copyright 2013
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.dkpro.core.clearnlp;
import static java.util.Arrays.asList;
import static org.apache.commons.io.IOUtils.closeQuietly;
import static org.apache.uima.fit.util.JCasUtil.select;
import static org.apache.uima.fit.util.JCasUtil.selectCovered;
import static org.apache.uima.util.Level.INFO;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.fit.util.FSCollectionFactory;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import com.clearnlp.classification.model.StringModel;
import com.clearnlp.component.AbstractComponent;
import com.clearnlp.component.AbstractStatisticalComponent;
import com.clearnlp.dependency.DEPArc;
import com.clearnlp.dependency.DEPLib;
import com.clearnlp.dependency.DEPNode;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.nlp.NLPGetter;
import com.clearnlp.nlp.NLPMode;
import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.api.resources.CasConfigurableProviderBase;
import de.tudarmstadt.ukp.dkpro.core.api.resources.CasConfigurableStreamProviderBase;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.api.semantics.type.SemArg;
import de.tudarmstadt.ukp.dkpro.core.api.semantics.type.SemArgLink;
import de.tudarmstadt.ukp.dkpro.core.api.semantics.type.SemPred;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.dependency.Dependency;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.dependency.ROOT;
/**
* ClearNLP semantic role labeller.
*/
@TypeCapability(
inputs = {
"de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence",
"de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token",
"de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.pos.POS",
"de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Lemma",
"de.tudarmstadt.ukp.dkpro.core.api.syntax.type.dependency.Dependency"},
outputs = {
"de.tudarmstadt.ukp.dkpro.core.api.semantics.type.SemPred",
"de.tudarmstadt.ukp.dkpro.core.api.semantics.type.SemArg"}
)
public class ClearNlpSemanticRoleLabeler
extends JCasAnnotator_ImplBase
{
/**
* Write the tag set(s) to the log when a model is loaded.
*/
public static final String PARAM_PRINT_TAGSET = ComponentParameters.PARAM_PRINT_TAGSET;
@ConfigurationParameter(name = PARAM_PRINT_TAGSET, mandatory = true, defaultValue = "false")
protected boolean printTagSet;
/**
* Use this language instead of the document language to resolve the model.
*/
public static final String PARAM_LANGUAGE = ComponentParameters.PARAM_LANGUAGE;
@ConfigurationParameter(name = PARAM_LANGUAGE, mandatory = false)
protected String language;
/**
* Variant of a model the model. Used to address a specific model if here are multiple models
* for one language.
*/
public static final String PARAM_VARIANT = ComponentParameters.PARAM_VARIANT;
@ConfigurationParameter(name = PARAM_VARIANT, mandatory = false)
protected String variant;
/**
* Location from which the predicate identifier model is read.
*/
public static final String PARAM_PRED_MODEL_LOCATION = "predModelLocation";
@ConfigurationParameter(name = PARAM_PRED_MODEL_LOCATION, mandatory = false)
protected String predModelLocation;
/**
* Location from which the roleset classification model is read.
*/
public static final String PARAM_ROLE_MODEL_LOCATION = "roleModelLocation";
@ConfigurationParameter(name = PARAM_ROLE_MODEL_LOCATION, mandatory = false)
protected String roleModelLocation;
/**
* Location from which the semantic role labeling model is read.
*/
public static final String PARAM_SRL_MODEL_LOCATION = "srlModelLocation";
@ConfigurationParameter(name = PARAM_SRL_MODEL_LOCATION, mandatory = false)
protected String srlModelLocation;
/**
* <p>Normally the arguments point only to the head words of arguments in the dependency tree.
* With this option enabled, they are expanded to the text covered by the minimal and maximal
* token offsets of all descendants (or self) of the head word.</p>
*
* <p>Warning: this parameter should be used with caution! For one, if the descentants of a
* head word cover a non-continuous region of the text, this information is lost. The arguments
* will appear to be spanning a continuous region. For another, the arguments may overlap with
* each other. E.g. if a sentence contains a relative clause with a verb, the subject of the
* main clause may be recognized as a dependent of the verb and may cause the whole main
* clause to be recorded in the argument.</p>
*/
public static final String PARAM_EXPAND_ARGUMENTS = "expandArguments";
@ConfigurationParameter(name = PARAM_EXPAND_ARGUMENTS, mandatory = true, defaultValue="false")
protected boolean expandArguments;
private CasConfigurableProviderBase<AbstractComponent> predicateFinder;
private CasConfigurableProviderBase<AbstractComponent> roleSetClassifier;
private CasConfigurableProviderBase<AbstractComponent> roleLabeller;
@Override
public void initialize(UimaContext aContext)
throws ResourceInitializationException
{
super.initialize(aContext);
predicateFinder = new CasConfigurableStreamProviderBase<AbstractComponent>()
{
{
setContextObject(ClearNlpSemanticRoleLabeler.this);
setDefault(ARTIFACT_ID, "${groupId}.clearnlp-model-pred-${language}-${variant}");
setDefault(LOCATION, "classpath:/de/tudarmstadt/ukp/dkpro/core/clearnlp/lib/"
+ "pred-${language}-${variant}.properties");
setDefault(VARIANT, "ontonotes");
setOverride(LOCATION, predModelLocation);
setOverride(LANGUAGE, language);
setOverride(VARIANT, variant);
}
@Override
protected AbstractComponent produceResource(InputStream aStream)
throws Exception
{
BufferedInputStream bis = null;
ObjectInputStream ois = null;
GZIPInputStream gis = null;
try{
gis = new GZIPInputStream(aStream);
bis = new BufferedInputStream(gis);
ois = new ObjectInputStream(bis);
AbstractComponent component = NLPGetter.getComponent(ois,
getAggregatedProperties().getProperty(LANGUAGE), NLPMode.MODE_PRED);
printTags(NLPMode.MODE_PRED, component);
return component;
}
catch (Exception e) {
throw new IOException(e);
}
finally {
closeQuietly(ois);
closeQuietly(bis);
closeQuietly(gis);
}
}
};
roleSetClassifier = new CasConfigurableStreamProviderBase<AbstractComponent>()
{
{
setContextObject(ClearNlpSemanticRoleLabeler.this);
setDefault(ARTIFACT_ID, "${groupId}.clearnlp-model-role-${language}-${variant}");
setDefault(LOCATION, "classpath:/de/tudarmstadt/ukp/dkpro/core/clearnlp/lib/"
+ "role-${language}-${variant}.properties");
setDefault(VARIANT, "ontonotes");
setOverride(LOCATION, roleModelLocation);
setOverride(LANGUAGE, language);
setOverride(VARIANT, variant);
}
@Override
protected AbstractComponent produceResource(InputStream aStream)
throws Exception
{
BufferedInputStream bis = null;
ObjectInputStream ois = null;
GZIPInputStream gis = null;
try{
gis = new GZIPInputStream(aStream);
bis = new BufferedInputStream(gis);
ois = new ObjectInputStream(bis);
AbstractComponent component = NLPGetter.getComponent(ois,
getAggregatedProperties().getProperty(LANGUAGE), NLPMode.MODE_ROLE);
printTags(NLPMode.MODE_ROLE, component);
return component;
}
catch (Exception e) {
throw new IOException(e);
}
finally {
closeQuietly(ois);
closeQuietly(bis);
closeQuietly(gis);
}
}
};
roleLabeller = new CasConfigurableStreamProviderBase<AbstractComponent>()
{
{
setContextObject(ClearNlpSemanticRoleLabeler.this);
setDefault(ARTIFACT_ID, "${groupId}.clearnlp-model-srl-${language}-${variant}");
setDefault(LOCATION, "classpath:/de/tudarmstadt/ukp/dkpro/core/clearnlp/lib/"
+ "srl-${language}-${variant}.properties");
setDefault(VARIANT, "ontonotes");
setOverride(LOCATION, srlModelLocation);
setOverride(LANGUAGE, language);
setOverride(VARIANT, variant);
}
@Override
protected AbstractComponent produceResource(InputStream aStream)
throws Exception
{
BufferedInputStream bis = null;
ObjectInputStream ois = null;
GZIPInputStream gis = null;
try{
gis = new GZIPInputStream(aStream);
bis = new BufferedInputStream(gis);
ois = new ObjectInputStream(bis);
AbstractComponent component = NLPGetter.getComponent(ois,
getAggregatedProperties().getProperty(LANGUAGE), NLPMode.MODE_SRL);
printTags(NLPMode.MODE_SRL, component);
return component;
}
catch (Exception e) {
throw new IOException(e);
}
finally {
closeQuietly(ois);
closeQuietly(bis);
closeQuietly(gis);
}
}
};
}
@Override
public void process(JCas aJCas)
throws AnalysisEngineProcessException
{
predicateFinder.configure(aJCas.getCas());
roleSetClassifier.configure(aJCas.getCas());
roleLabeller.configure(aJCas.getCas());
// Iterate over all sentences
for (Sentence sentence : select(aJCas, Sentence.class)) {
List<Token> tokens = selectCovered(aJCas, Token.class, sentence);
DEPTree tree = new DEPTree();
// Generate:
// - DEPNode
// - pos tags
// - lemma
for (int i = 0; i < tokens.size(); i++) {
Token t = tokens.get(i);
DEPNode node = new DEPNode(i + 1, tokens.get(i).getCoveredText());
node.pos = t.getPos().getPosValue();
node.lemma = t.getLemma().getValue();
tree.add(node);
}
// Generate:
// Dependency relations
for (Dependency dep : selectCovered(Dependency.class, sentence)) {
if (dep instanceof ROOT) {
// #736 ClearNlpSemanticRoleLabelerTest gets caught in infinite loop
// ClearNLP parser creates roots that do not have a head. We have to replicate
// this here to avoid running into an endless loop.
continue;
}
int headIndex = tokens.indexOf(dep.getGovernor());
int tokenIndex = tokens.indexOf(dep.getDependent());
DEPNode token = tree.get(tokenIndex + 1);
DEPNode head = tree.get(headIndex + 1);
token.setHead(head, dep.getDependencyType());
}
// For the root node
for (int i = 0; i < tokens.size(); i++) {
DEPNode parserNode = tree.get(i + 1);
if(parserNode.getLabel() == null){
int headIndex = tokens.indexOf(null);
DEPNode head = tree.get(headIndex + 1);
parserNode.setHead(head, "root");
}
}
// Do the SRL
predicateFinder.getResource().process(tree);
roleSetClassifier.getResource().process(tree);
roleLabeller.getResource().process(tree);
// Convert the results into UIMA annotations
Map<Token, SemPred> predicates = new HashMap<>();
Map<SemPred, List<SemArgLink>> predArgs = new HashMap<>();
for (int i = 0; i < tokens.size(); i++) {
DEPNode parserNode = tree.get(i + 1);
Token argumentToken = tokens.get(i);
for (DEPArc argPredArc : parserNode.getSHeads()) {
Token predToken = tokens.get(argPredArc.getNode().id - 1);
// Instantiate the semantic predicate annotation if it hasn't been done yet
SemPred pred = predicates.get(predToken);
if (pred == null) {
// Create the semantic predicate annotation itself
pred = new SemPred(aJCas, predToken.getBegin(), predToken.getEnd());
pred.setCategory(argPredArc.getNode().getFeat(DEPLib.FEAT_PB));
pred.addToIndexes();
predicates.put(predToken, pred);
// Prepare a list to store its arguments
predArgs.put(pred, new ArrayList<>());
}
// Instantiate the semantic argument annotation
SemArg arg = new SemArg(aJCas);
if (expandArguments) {
List<DEPNode> descendents = parserNode.getDescendents(Integer.MAX_VALUE)
.stream()
.map(arc -> arc.getNode())
.collect(Collectors.toList());
descendents.add(parserNode);
List<Token> descTokens = descendents.stream()
.map(node -> tokens.get(node.id - 1))
.collect(Collectors.toList());
int begin = descTokens.stream().mapToInt(t -> t.getBegin()).min().getAsInt();
int end = descTokens.stream().mapToInt(t -> t.getEnd()).max().getAsInt();
arg.setBegin(begin);
arg.setEnd(end);
}
else {
arg.setBegin(argumentToken.getBegin());
arg.setEnd(argumentToken.getEnd());
}
arg.addToIndexes();
SemArgLink link = new SemArgLink(aJCas);
link.setRole(argPredArc.getLabel());
link.setTarget(arg);
// Remember to which predicate this argument belongs
predArgs.get(pred).add(link);
}
}
for (Entry<SemPred, List<SemArgLink>> e : predArgs.entrySet()) {
e.getKey().setArguments(FSCollectionFactory.createFSArray(aJCas, e.getValue()));
}
}
}
private void printTags(String aType, AbstractComponent aComponent)
{
if (printTagSet && (aComponent instanceof AbstractStatisticalComponent)) {
AbstractStatisticalComponent component = (AbstractStatisticalComponent) aComponent;
Set<String> tagSet = new HashSet<String>();
for (StringModel model : component.getModels()) {
tagSet.addAll(asList(model.getLabels()));
}
List<String> tagList = new ArrayList<String>(tagSet);
Collections.sort(tagList);
StringBuilder sb = new StringBuilder();
sb.append("Model of " + aType + " contains [").append(tagList.size())
.append("] tags: ");
for (String tag : tagList) {
sb.append(tag);
sb.append(" ");
}
getContext().getLogger().log(INFO, sb.toString());
}
}
}