/* * Copyright 2016 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.dkpro.core.opennlp; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Collections; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.commons.io.IOUtils; import org.apache.uima.UimaContext; import org.apache.uima.analysis_engine.AnalysisEngineProcessException; import org.apache.uima.fit.component.JCasConsumer_ImplBase; import org.apache.uima.fit.descriptor.ConfigurationParameter; import org.apache.uima.jcas.JCas; import org.apache.uima.resource.ResourceInitializationException; import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters; import de.tudarmstadt.ukp.dkpro.core.opennlp.internal.CasNameSampleStream; import opennlp.tools.ml.BeamSearch; import opennlp.tools.ml.EventTrainer; import opennlp.tools.ml.maxent.GISTrainer; import opennlp.tools.ml.maxent.quasinewton.QNTrainer; import opennlp.tools.ml.perceptron.PerceptronTrainer; import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer; import opennlp.tools.namefind.BilouCodec; import opennlp.tools.namefind.BioCodec; import 
opennlp.tools.namefind.NameFinderME;
import opennlp.tools.namefind.TokenNameFinderFactory;
import opennlp.tools.namefind.TokenNameFinderModel;
import opennlp.tools.util.SequenceCodec;
import opennlp.tools.util.TrainingParameters;

/**
 * Train a named entity recognizer model for OpenNLP.
 * <p>
 * Training runs on a single background thread started in {@link #initialize(UimaContext)}. The
 * CASes passed to {@link #process(JCas)} are handed to the trainer through a
 * {@link CasNameSampleStream}. The trained model is written to {@link #PARAM_TARGET_LOCATION}
 * in {@link #collectionProcessComplete()}.
 */
public class OpenNlpNamedEntityRecognizerTrainer
    extends JCasConsumer_ImplBase
{
    /**
     * Label encodings available for representing named entity spans as token tags.
     */
    public enum SequenceEncoding
    {
        BIO(BioCodec.class),
        BILOU(BilouCodec.class);

        private final Class<? extends SequenceCodec<String>> codec;

        SequenceEncoding(Class<? extends SequenceCodec<String>> aCodec)
        {
            codec = aCodec;
        }

        /**
         * Creates a new codec instance using the codec class's no-argument constructor.
         *
         * @return a fresh codec instance.
         * @throws IllegalStateException
         *             if the codec class cannot be instantiated reflectively.
         */
        private SequenceCodec<String> getCodec()
        {
            try {
                // Class.newInstance() is deprecated because it lets checked constructor
                // exceptions escape unchecked; go through the Constructor instead.
                return codec.getDeclaredConstructor().newInstance();
            }
            catch (ReflectiveOperationException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Language code passed to {@code NameFinderME.train(...)}.
     */
    public static final String PARAM_LANGUAGE = ComponentParameters.PARAM_LANGUAGE;
    @ConfigurationParameter(name = PARAM_LANGUAGE, mandatory = true)
    private String language;

    /**
     * File to which the trained model is serialized.
     */
    public static final String PARAM_TARGET_LOCATION = ComponentParameters.PARAM_TARGET_LOCATION;
    @ConfigurationParameter(name = PARAM_TARGET_LOCATION, mandatory = true)
    private File targetLocation;

    /**
     * Training algorithm, passed to OpenNLP as {@link TrainingParameters#ALGORITHM_PARAM}.
     *
     * @see GISTrainer#MAXENT_VALUE
     * @see QNTrainer#MAXENT_QN_VALUE
     * @see PerceptronTrainer#PERCEPTRON_VALUE
     * @see SimplePerceptronSequenceTrainer#PERCEPTRON_SEQUENCE_VALUE
     */
    public static final String PARAM_ALGORITHM = "algorithm";
    @ConfigurationParameter(name = PARAM_ALGORITHM, mandatory = true, defaultValue = PerceptronTrainer.PERCEPTRON_VALUE)
    private String algorithm;

    /**
     * Trainer type.
     * <p>
     * NOTE(review): this value is currently not forwarded to OpenNLP (see
     * {@link #initialize(UimaContext)}).
     */
    public static final String PARAM_TRAINER_TYPE = "trainerType";
    @ConfigurationParameter(name = PARAM_TRAINER_TYPE, mandatory = true, defaultValue = EventTrainer.EVENT_VALUE)
    private String trainerType;

    /**
     * Number of training iterations, passed as {@link TrainingParameters#ITERATIONS_PARAM}.
     */
    public static final String PARAM_ITERATIONS = "iterations";
    @ConfigurationParameter(name = PARAM_ITERATIONS, mandatory = true, defaultValue = "300")
    private int iterations;

    /**
     * Feature cutoff, passed as {@link TrainingParameters#CUTOFF_PARAM}.
     */
    public static final String PARAM_CUTOFF = "cutoff";
    @ConfigurationParameter(name = PARAM_CUTOFF, mandatory = true, defaultValue = "0")
    private int cutoff;

    /**
     * Beam size, passed as {@link BeamSearch#BEAM_SIZE_PARAMETER}.
     *
     * @see NameFinderME#DEFAULT_BEAM_SIZE
     */
    public static final String PARAM_BEAMSIZE = "beamSize";
    @ConfigurationParameter(name = PARAM_BEAMSIZE, mandatory = true, defaultValue = "3")
    private int beamSize;

    /**
     * Optional file whose raw bytes are handed to {@link TokenNameFinderFactory} as the
     * feature generator configuration. When unset, {@code null} is passed instead (presumably
     * selecting OpenNLP's default feature generation — confirm).
     */
    public static final String PARAM_FEATURE_GEN = "featureGen";
    @ConfigurationParameter(name = PARAM_FEATURE_GEN, mandatory = false)
    private File featureGen;

    /**
     * Label encoding used for the entity sequences.
     */
    public static final String PARAM_SEQUENCE_ENCODING = "sequenceEncoding";
    @ConfigurationParameter(name = PARAM_SEQUENCE_ENCODING, mandatory = true, defaultValue="BILOU")
    private SequenceEncoding sequenceEncoding;

    private CasNameSampleStream stream;

    // Runs the training task; created in initialize() and shut down once training is over.
    private ExecutorService executor;

    private Future<TokenNameFinderModel> future;

    @Override
    public void initialize(UimaContext aContext)
        throws ResourceInitializationException
    {
        super.initialize(aContext);

        stream = new CasNameSampleStream();

        TrainingParameters params = new TrainingParameters();
        params.put(TrainingParameters.ALGORITHM_PARAM, algorithm);
        // NOTE(review): PARAM_TRAINER_TYPE is not forwarded. Deriving it via
        // TrainerFactory.getTrainerType(...) was tried and left disabled. Confirm whether the
        // parameter should be wired in or removed.
        params.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(iterations));
        params.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(cutoff));
        params.put(BeamSearch.BEAM_SIZE_PARAMETER, Integer.toString(beamSize));

        byte[] featureGenCfg = loadFeatureGen(featureGen);

        Callable<TokenNameFinderModel> trainTask = () -> {
            try {
                return NameFinderME.train(language, null, stream, params,
                        new TokenNameFinderFactory(featureGenCfg,
                                Collections.<String, Object> emptyMap(),
                                sequenceEncoding.getCodec()));
            }
            catch (Throwable e) {
                // Close the stream when training fails (presumably so that process() does not
                // keep feeding a dead trainer — confirm against CasNameSampleStream). A close
                // failure must not hide the training failure, so attach it as suppressed.
                try {
                    stream.close();
                }
                catch (IOException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
        };

        // Created here rather than in a field initializer so that a component that is
        // destroyed and initialized again gets a live executor.
        executor = Executors.newSingleThreadExecutor();
        future = executor.submit(trainTask);
    }

    /**
     * Hands the CAS to the background trainer unless the training task has been cancelled.
     */
    @Override
    public void process(JCas aJCas)
        throws AnalysisEngineProcessException
    {
        if (!future.isCancelled()) {
            stream.send(aJCas);
        }
    }

    /**
     * Closes the sample stream, waits for training to finish and serializes the model to
     * {@link #PARAM_TARGET_LOCATION}.
     *
     * @throws AnalysisEngineProcessException
     *             if closing the stream fails, training fails or is interrupted, or the model
     *             cannot be written.
     */
    @Override
    public void collectionProcessComplete()
        throws AnalysisEngineProcessException
    {
        TokenNameFinderModel model;
        try {
            // Closing the stream marks the end of the training data.
            stream.close();
            model = future.get();
        }
        catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
        catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new AnalysisEngineProcessException(e);
        }
        catch (ExecutionException e) {
            // Report the actual training failure rather than the executor's wrapper.
            throw new AnalysisEngineProcessException(e.getCause());
        }
        finally {
            // The trainer thread is no longer needed. Without this, the executor's non-daemon
            // worker thread stays alive; on failure paths this also interrupts a trainer that
            // is still running.
            executor.shutdownNow();
        }

        try (OutputStream out = new FileOutputStream(targetLocation)) {
            model.serialize(out);
        }
        catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * Stops the training thread if the component is torn down without
     * {@link #collectionProcessComplete()} having shut it down.
     */
    @Override
    public void destroy()
    {
        if (executor != null) {
            executor.shutdownNow();
        }
        super.destroy();
    }

    /**
     * Reads the feature generator configuration file fully into memory.
     *
     * @param aFile
     *            the configuration file, or {@code null} if none was configured.
     * @return the file contents, or {@code null} if {@code aFile} is {@code null}.
     * @throws ResourceInitializationException
     *             if the file cannot be read.
     */
    private byte[] loadFeatureGen(File aFile)
        throws ResourceInitializationException
    {
        if (aFile == null) {
            return null;
        }

        try (InputStream in = new FileInputStream(aFile)) {
            return IOUtils.toByteArray(in);
        }
        catch (IOException e) {
            throw new ResourceInitializationException(e);
        }
    }
}