/*
* Copyright 2016
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.dkpro.core.opennlp;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.URL;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.regex.Pattern;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasConsumer_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils;
import de.tudarmstadt.ukp.dkpro.core.opennlp.internal.CasTokenSampleStream;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.ml.EventTrainer;
import opennlp.tools.ml.maxent.GISTrainer;
import opennlp.tools.tokenize.TokenizerFactory;
import opennlp.tools.tokenize.TokenizerME;
import opennlp.tools.tokenize.TokenizerModel;
import opennlp.tools.tokenize.lang.Factory;
import opennlp.tools.util.TrainingParameters;
/**
 * Train a tokenizer model for OpenNLP.
 * <p>
 * The training itself runs on a background thread that is started in
 * {@link #initialize(UimaContext)} and reads its samples from a {@link CasTokenSampleStream}.
 * Every CAS passed to {@link #process(JCas)} is handed to that stream. When
 * {@link #collectionProcessComplete()} is called, the stream is closed, the trained model is
 * awaited and then serialized to {@link #PARAM_TARGET_LOCATION}.
 */
public class OpenNlpTokenTrainer
    extends JCasConsumer_ImplBase
{
    /**
     * Language code handed to the {@link TokenizerFactory}.
     */
    public static final String PARAM_LANGUAGE = ComponentParameters.PARAM_LANGUAGE;
    @ConfigurationParameter(name = PARAM_LANGUAGE, mandatory = true)
    private String language;

    /**
     * File to which the trained model is serialized.
     */
    public static final String PARAM_TARGET_LOCATION = ComponentParameters.PARAM_TARGET_LOCATION;
    @ConfigurationParameter(name = PARAM_TARGET_LOCATION, mandatory = true)
    private File targetLocation;

    /**
     * Value passed to OpenNLP as {@link TrainingParameters#ALGORITHM_PARAM}.
     */
    public static final String PARAM_ALGORITHM = "algorithm";
    @ConfigurationParameter(name = PARAM_ALGORITHM, mandatory = true, defaultValue = GISTrainer.MAXENT_VALUE)
    private String algorithm;

    /**
     * Value passed to OpenNLP as {@link TrainingParameters#TRAINER_TYPE_PARAM}.
     */
    public static final String PARAM_TRAINER_TYPE = "trainerType";
    @ConfigurationParameter(name = PARAM_TRAINER_TYPE, mandatory = true, defaultValue = EventTrainer.EVENT_VALUE)
    private String trainerType;

    /**
     * Value passed to OpenNLP as {@link TrainingParameters#ITERATIONS_PARAM}.
     */
    public static final String PARAM_ITERATIONS = "iterations";
    @ConfigurationParameter(name = PARAM_ITERATIONS, mandatory = true, defaultValue = "100")
    private int iterations;

    /**
     * Value passed to OpenNLP as {@link TrainingParameters#CUTOFF_PARAM}.
     */
    public static final String PARAM_CUTOFF = "cutoff";
    @ConfigurationParameter(name = PARAM_CUTOFF, mandatory = true, defaultValue = "5")
    private int cutoff;

    /**
     * Alpha-numeric optimization flag handed to the {@link TokenizerFactory}.
     */
    public static final String PARAM_USE_ALPHANUMERIC_OPTIMIZATION = "useAlphaNumericOptimization";
    @ConfigurationParameter(name = PARAM_USE_ALPHANUMERIC_OPTIMIZATION, mandatory = true, defaultValue = "true")
    private boolean useAlphaNumericOptimization;

    /**
     * Alpha-numeric pattern handed to the {@link TokenizerFactory}.
     */
    public static final String PARAM_ALPHA_NUMERIC_PATTERN = "alphaNumericPattern";
    @ConfigurationParameter(name = PARAM_ALPHA_NUMERIC_PATTERN, mandatory = false, defaultValue = Factory.DEFAULT_ALPHANUMERIC)
    private Pattern alphaNumericPattern;

    /**
     * Optional location of an abbreviation dictionary with one entry per line. If not set, no
     * abbreviation dictionary is used.
     */
    public static final String PARAM_ABBREVIATION_DICTIONARY_LOCATION = "abbreviationDictionaryLocation";
    @ConfigurationParameter(name = PARAM_ABBREVIATION_DICTIONARY_LOCATION, mandatory = false)
    private String abbreviationDictionaryLocation;

    /**
     * Character encoding used to read the abbreviation dictionary.
     */
    public static final String PARAM_ABBREVIATION_DICTIONARY_ENCODING = "abbreviationDictionaryEncoding";
    @ConfigurationParameter(name = PARAM_ABBREVIATION_DICTIONARY_ENCODING, mandatory = true, defaultValue = "UTF-8")
    private String abbreviationDictionaryEncoding;

    private CasTokenSampleStream stream;
    private Dictionary abbreviationDictionary;
    // Created in initialize() and released in collectionProcessComplete()/destroy() so that the
    // worker thread does not outlive the component.
    private ExecutorService executor;
    private Future<TokenizerModel> future;

    /**
     * Sets up the sample stream, loads the optional abbreviation dictionary and starts the
     * background training task.
     *
     * @param aContext
     *            the UIMA context of this component.
     * @throws ResourceInitializationException
     *             if the abbreviation dictionary cannot be resolved or read.
     */
    @Override
    public void initialize(UimaContext aContext)
        throws ResourceInitializationException
    {
        super.initialize(aContext);

        stream = new CasTokenSampleStream();

        TrainingParameters params = new TrainingParameters();
        params.put(TrainingParameters.ALGORITHM_PARAM, algorithm);
        params.put(TrainingParameters.TRAINER_TYPE_PARAM, trainerType);
        params.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(iterations));
        params.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(cutoff));

        if (abbreviationDictionaryLocation != null) {
            try {
                URL abbrevUrl = ResourceUtils.resolveLocation(abbreviationDictionaryLocation,
                        aContext);
                try (InputStream is = abbrevUrl.openStream()) {
                    abbreviationDictionary = Dictionary.parseOneEntryPerLine(
                            new InputStreamReader(is, abbreviationDictionaryEncoding));
                }
            }
            catch (IOException e) {
                throw new ResourceInitializationException(e);
            }
        }
        else {
            abbreviationDictionary = null;
        }

        Callable<TokenizerModel> trainTask = () -> {
            try {
                TokenizerFactory factory = new TokenizerFactory(language,
                        abbreviationDictionary, useAlphaNumericOptimization, alphaNumericPattern);
                return TokenizerME.train(stream, factory, params);
            }
            catch (Throwable e) {
                // Close the stream on failure, but never let a problem while closing hide the
                // actual training failure.
                try {
                    stream.close();
                }
                catch (IOException closeError) {
                    e.addSuppressed(closeError);
                }
                throw e;
            }
        };

        // The executor is created only once everything else has succeeded, so a failed
        // initialization does not leave one behind.
        executor = Executors.newSingleThreadExecutor();
        future = executor.submit(trainTask);
    }

    /**
     * Hands the CAS to the sample stream read by the training task.
     *
     * @param aJCas
     *            the CAS providing training data.
     * @throws AnalysisEngineProcessException
     *             declared by the UIMA contract.
     */
    @Override
    public void process(JCas aJCas)
        throws AnalysisEngineProcessException
    {
        // The future is never cancelled by this class, so isCancelled() would always be false.
        // isDone() also covers a training task that terminated early with an exception (and
        // closed the stream); in that case nothing is consuming the stream anymore and the
        // failure is reported by future.get() in collectionProcessComplete().
        if (!future.isDone()) {
            stream.send(aJCas);
        }
    }

    /**
     * Closes the sample stream, waits for the training task to finish and writes the resulting
     * model to the target location.
     *
     * @throws AnalysisEngineProcessException
     *             if closing the stream fails, if the training failed or was interrupted, or if
     *             the model cannot be written.
     */
    @Override
    public void collectionProcessComplete()
        throws AnalysisEngineProcessException
    {
        try {
            try {
                stream.close();
            }
            catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }

            TokenizerModel model;
            try {
                model = future.get();
            }
            catch (InterruptedException e) {
                // Restore the interrupt flag so that callers further up can still observe it.
                Thread.currentThread().interrupt();
                throw new AnalysisEngineProcessException(e);
            }
            catch (ExecutionException e) {
                throw new AnalysisEngineProcessException(e);
            }

            try (OutputStream out = new FileOutputStream(targetLocation)) {
                model.serialize(out);
            }
            catch (IOException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
        finally {
            // Release the worker thread; otherwise the non-daemon thread stays alive.
            executor.shutdown();
        }
    }

    /**
     * Stops the training executor in case the component is destroyed without
     * {@link #collectionProcessComplete()} having released it.
     */
    @Override
    public void destroy()
    {
        if (executor != null) {
            executor.shutdownNow();
        }
        super.destroy();
    }
}