/*
* Copyright 2016
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.dkpro.core.opennlp;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasConsumer_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.opennlp.internal.CasNameSampleStream;
import opennlp.tools.ml.BeamSearch;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.maxent.GISTrainer;
import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
import opennlp.tools.ml.perceptron.PerceptronTrainer;
import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
import opennlp.tools.namefind.BilouCodec;
import opennlp.tools.namefind.BioCodec;
import opennlp.tools.namefind.NameFinderME;
import opennlp.tools.namefind.TokenNameFinderFactory;
import opennlp.tools.namefind.TokenNameFinderModel;
import opennlp.tools.util.SequenceCodec;
import opennlp.tools.util.TrainingParameters;
/**
 * Train a named entity recognizer model for OpenNLP.
 * <p>
 * Training runs on a single background thread that is started in
 * {@link #initialize(UimaContext)}. Every CAS received in {@link #process(JCas)} is forwarded
 * to that thread through a {@link CasNameSampleStream}. {@link #collectionProcessComplete()}
 * closes the stream, waits for training to finish and serializes the resulting model to the
 * file configured via {@link #PARAM_TARGET_LOCATION}.
 */
public class OpenNlpNamedEntityRecognizerTrainer
    extends JCasConsumer_ImplBase
{
    /**
     * Sequence encoding whose codec is handed to the OpenNLP {@link TokenNameFinderFactory}.
     */
    public enum SequenceEncoding {
        BIO(BioCodec.class), BILOU(BilouCodec.class);

        private final Class<? extends SequenceCodec<String>> codec;

        SequenceEncoding(Class<? extends SequenceCodec<String>> aCodec)
        {
            codec = aCodec;
        }

        /**
         * Creates a new codec instance using the codec class's no-argument constructor.
         *
         * @return a fresh codec for this encoding.
         * @throws IllegalStateException
         *             if the codec class cannot be instantiated reflectively.
         */
        private SequenceCodec<String> getCodec()
        {
            try {
                // Class#newInstance() is deprecated (it can propagate undeclared checked
                // exceptions from the constructor); go through the Constructor instead.
                return codec.getDeclaredConstructor().newInstance();
            }
            catch (ReflectiveOperationException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Language code passed to OpenNLP when training the model.
     */
    public static final String PARAM_LANGUAGE = ComponentParameters.PARAM_LANGUAGE;
    @ConfigurationParameter(name = PARAM_LANGUAGE, mandatory = true)
    private String language;

    /**
     * File to which the trained model is serialized.
     */
    public static final String PARAM_TARGET_LOCATION = ComponentParameters.PARAM_TARGET_LOCATION;
    @ConfigurationParameter(name = PARAM_TARGET_LOCATION, mandatory = true)
    private File targetLocation;

    /**
     * Training algorithm, forwarded under {@link TrainingParameters#ALGORITHM_PARAM}.
     * Defaults to {@link PerceptronTrainer#PERCEPTRON_VALUE}.
     *
     * @see GISTrainer#MAXENT_VALUE
     * @see QNTrainer#MAXENT_QN_VALUE
     * @see PerceptronTrainer#PERCEPTRON_VALUE
     * @see SimplePerceptronSequenceTrainer#PERCEPTRON_SEQUENCE_VALUE
     */
    public static final String PARAM_ALGORITHM = "algorithm";
    @ConfigurationParameter(name = PARAM_ALGORITHM, mandatory = true, defaultValue = PerceptronTrainer.PERCEPTRON_VALUE)
    private String algorithm;

    /**
     * Trainer type, forwarded under {@link TrainingParameters#TRAINER_TYPE_PARAM}.
     * Defaults to {@link EventTrainer#EVENT_VALUE}.
     */
    public static final String PARAM_TRAINER_TYPE = "trainerType";
    @ConfigurationParameter(name = PARAM_TRAINER_TYPE, mandatory = true, defaultValue = EventTrainer.EVENT_VALUE)
    private String trainerType;

    /**
     * Number of training iterations, forwarded under
     * {@link TrainingParameters#ITERATIONS_PARAM}.
     */
    public static final String PARAM_ITERATIONS = "iterations";
    @ConfigurationParameter(name = PARAM_ITERATIONS, mandatory = true, defaultValue = "300")
    private int iterations;

    /**
     * Cutoff value, forwarded under {@link TrainingParameters#CUTOFF_PARAM}.
     */
    public static final String PARAM_CUTOFF = "cutoff";
    @ConfigurationParameter(name = PARAM_CUTOFF, mandatory = true, defaultValue = "0")
    private int cutoff;

    /**
     * Beam size, forwarded under {@link BeamSearch#BEAM_SIZE_PARAMETER}.
     *
     * @see NameFinderME#DEFAULT_BEAM_SIZE
     */
    public static final String PARAM_BEAMSIZE = "beamSize";
    @ConfigurationParameter(name = PARAM_BEAMSIZE, mandatory = true, defaultValue = "3")
    private int beamSize;

    /**
     * Optional feature generator descriptor file. Its raw bytes are passed to the
     * {@link TokenNameFinderFactory}; when unset, {@code null} is passed instead.
     */
    public static final String PARAM_FEATURE_GEN = "featureGen";
    @ConfigurationParameter(name = PARAM_FEATURE_GEN, mandatory = false)
    private File featureGen;

    /**
     * Sequence encoding used for the name finder outcomes; defaults to BILOU.
     */
    public static final String PARAM_SEQUENCE_ENCODING = "sequenceEncoding";
    @ConfigurationParameter(name = PARAM_SEQUENCE_ENCODING, mandatory = true, defaultValue="BILOU")
    private SequenceEncoding sequenceEncoding;

    /** Channel through which CASes are handed from {@link #process(JCas)} to the trainer. */
    private CasNameSampleStream stream;
    /** Runs the training task; created in {@link #initialize(UimaContext)}. */
    private ExecutorService executor;
    /** Result of the background training task. */
    private Future<TokenNameFinderModel> future;

    /**
     * Builds the OpenNLP training parameters and starts training on a background thread.
     * The training task consumes samples from the sample stream as they are sent in
     * {@link #process(JCas)}.
     *
     * @throws ResourceInitializationException
     *             if the feature generator descriptor cannot be read.
     */
    @Override
    public void initialize(UimaContext aContext)
        throws ResourceInitializationException
    {
        super.initialize(aContext);

        stream = new CasNameSampleStream();

        TrainingParameters params = new TrainingParameters();
        params.put(TrainingParameters.ALGORITHM_PARAM, algorithm);
        // Forward the configured trainer type; otherwise PARAM_TRAINER_TYPE has no effect.
        params.put(TrainingParameters.TRAINER_TYPE_PARAM, trainerType);
        params.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(iterations));
        params.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(cutoff));
        params.put(BeamSearch.BEAM_SIZE_PARAMETER, Integer.toString(beamSize));

        byte[] featureGenCfg = loadFeatureGen(featureGen);

        Callable<TokenNameFinderModel> trainTask = () -> {
            try {
                return NameFinderME.train(language, null, stream, params,
                        new TokenNameFinderFactory(featureGenCfg,
                                Collections.<String, Object> emptyMap(),
                                sequenceEncoding.getCodec()));
            }
            catch (Throwable e) {
                // Close the sample stream on failure. The training failure remains the
                // primary exception; a failure while closing is attached as suppressed.
                try {
                    stream.close();
                }
                catch (IOException closeError) {
                    e.addSuppressed(closeError);
                }
                throw e;
            }
        };

        // Created here rather than at field initialization so that a component which is
        // destroyed (shutting the executor down) and initialized again gets a live executor.
        executor = Executors.newSingleThreadExecutor();
        future = executor.submit(trainTask);
    }

    /**
     * Forwards the CAS to the training thread as training data. Once the training task has
     * terminated early (failed or cancelled), further documents are skipped; a training
     * failure surfaces from {@link #collectionProcessComplete()}.
     */
    @Override
    public void process(JCas aJCas)
        throws AnalysisEngineProcessException
    {
        // isDone() covers cancellation as well as failure; in the failure case the
        // training task has already closed the stream.
        if (!future.isDone()) {
            stream.send(aJCas);
        }
    }

    /**
     * Closes the sample stream, waits for the training task to finish and writes the model
     * to the target location. The training executor is always shut down afterwards.
     *
     * @throws AnalysisEngineProcessException
     *             if closing the stream, training, or writing the model fails, or if the
     *             calling thread is interrupted while waiting for training.
     */
    @Override
    public void collectionProcessComplete()
        throws AnalysisEngineProcessException
    {
        try {
            // Closing the stream presumably signals end-of-data to the training thread;
            // the future is only awaited afterwards.
            try {
                stream.close();
            }
            catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }

            TokenNameFinderModel model;
            try {
                model = future.get();
            }
            catch (InterruptedException e) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new AnalysisEngineProcessException(e);
            }
            catch (ExecutionException e) {
                throw new AnalysisEngineProcessException(e);
            }

            try (OutputStream out = new FileOutputStream(targetLocation)) {
                model.serialize(out);
            }
            catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
        finally {
            shutdownTraining();
        }
    }

    /**
     * Stops the background training (if still running) before releasing the component.
     */
    @Override
    public void destroy()
    {
        shutdownTraining();
        super.destroy();
    }

    /**
     * Cancels the training task if it has not finished yet and shuts the executor down.
     * Safe to call repeatedly and before {@link #initialize(UimaContext)} has completed.
     */
    private void shutdownTraining()
    {
        if (future != null && !future.isDone()) {
            // NOTE(review): interruption only helps if the stream/trainer react to it;
            // CasNameSampleStream is not visible here.
            future.cancel(true);
        }
        if (executor != null) {
            executor.shutdownNow();
        }
    }

    /**
     * Reads the feature generator descriptor into memory.
     *
     * @param aFile
     *            the descriptor file, or {@code null} if none is configured.
     * @return the file contents, or {@code null} if {@code aFile} is {@code null}.
     * @throws ResourceInitializationException
     *             if the file cannot be read.
     */
    private byte[] loadFeatureGen(File aFile)
        throws ResourceInitializationException
    {
        if (aFile == null) {
            return null;
        }

        try (InputStream in = new FileInputStream(aFile)) {
            return IOUtils.toByteArray(in);
        }
        catch (IOException e) {
            throw new ResourceInitializationException(e);
        }
    }
}