/** * Copyright 2007-2014 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package de.tudarmstadt.ukp.dkpro.core.stanfordnlp; import static org.apache.uima.fit.util.JCasUtil.select; import static org.apache.uima.fit.util.JCasUtil.selectCovered; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.nio.charset.StandardCharsets; import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Properties; import org.apache.commons.io.IOUtils; import org.apache.uima.UimaContext; import org.apache.uima.analysis_engine.AnalysisEngineProcessException; import org.apache.uima.cas.Feature; import org.apache.uima.cas.Type; import org.apache.uima.fit.component.JCasConsumer_ImplBase; import org.apache.uima.fit.descriptor.ConfigurationParameter; import org.apache.uima.fit.util.JCasUtil; import org.apache.uima.jcas.JCas; import org.apache.uima.resource.ResourceInitializationException; import de.tudarmstadt.ukp.dkpro.core.api.io.IobEncoder; import de.tudarmstadt.ukp.dkpro.core.api.ner.type.NamedEntity; import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters; import 
de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import edu.stanford.nlp.ie.crf.CRFClassifier;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.sequences.SeqClassifierFlags;

/**
 * Train a NER model for Stanford CoreNLP Named Entity Recognizer.
 * <p>
 * Each processed document is converted into a temporary tab-separated file with one
 * "token TAB label" line per token and an empty line after each sentence. When the collection
 * is complete, a {@link CRFClassifier} is trained on that file and serialized to
 * {@link #PARAM_TARGET_LOCATION}. The temporary file is deleted afterwards.
 */
public class StanfordNamedEntityRecognizerTrainer
    extends JCasConsumer_ImplBase
{
    /**
     * Location of the target model file.
     */
    public static final String PARAM_TARGET_LOCATION = ComponentParameters.PARAM_TARGET_LOCATION;
    @ConfigurationParameter(name = PARAM_TARGET_LOCATION, mandatory = true)
    private File targetLocation;

    /**
     * Properties file containing the training parameters. The <code>trainFile</code>,
     * <code>trainFileList</code>, <code>trainFiles</code>, <code>serializeTo</code> and
     * <code>map</code> parameters in this file are ignored/overridden. If no file is given, the
     * default settings of {@link SeqClassifierFlags} are used.
     */
    public static final String PARAM_PROPERTIES_LOCATION = "propertiesFile";
    @ConfigurationParameter(name = PARAM_PROPERTIES_LOCATION, mandatory = false)
    private File propertiesFile;

    /**
     * Label set to use for training. Options: IOB1, IOB2, IOE1, IOE2, SBIEO, IO, BIO, BILOU,
     * noprefix
     * <p>
     * Default: noprefix
     */
    public static final String PARAM_LABEL_SET = "entitySubClassification";
    @ConfigurationParameter(name = PARAM_LABEL_SET, mandatory = false, defaultValue = "noprefix")
    private String entitySubClassification;

    /**
     * Flag to keep the label set specified by PARAM_LABEL_SET. If set to false, representation is
     * mapped to IOB1 on output.
     * <p>
     * Default: true
     */
    public static final String PARAM_RETAIN_CLASS = "retainClassification";
    @ConfigurationParameter(name = PARAM_RETAIN_CLASS, mandatory = false, defaultValue = "true")
    private boolean retainClassification;

    /** Temporary training data file; created on the first process() call of a collection. */
    private File tempData;

    /** Writer for {@link #tempData}; non-null exactly when {@link #tempData} is non-null. */
    private PrintWriter out;

    @Override
    public void initialize(UimaContext aContext)
        throws ResourceInitializationException
    {
        super.initialize(aContext);
    }

    /**
     * Appends the sentences of the given document to the temporary training data file, creating
     * that file on first use.
     *
     * @param aJCas
     *            the document to convert.
     * @throws AnalysisEngineProcessException
     *             if the temporary file cannot be created.
     */
    @Override
    public void process(JCas aJCas)
        throws AnalysisEngineProcessException
    {
        if (tempData == null) {
            File file = null;
            try {
                file = File.createTempFile("dkpro-stanford-ner-trainer", ".tsv");
                getLogger().info(String.format("Created temp file: %s", file.getAbsolutePath()));
                out = new PrintWriter(new OutputStreamWriter(new FileOutputStream(file),
                        StandardCharsets.UTF_8));
                // Only publish the file once the writer is open, so tempData != null always
                // implies a usable writer.
                tempData = file;
            }
            catch (IOException e) {
                if (file != null) {
                    file.delete();
                }
                throw new AnalysisEngineProcessException(e);
            }
        }

        convert(aJCas, out);
        getLogger().info("Conversion process complete.");
    }

    /**
     * Writes every sentence of the CAS that covers at least one {@link NamedEntity} as one
     * "token TAB label" line per token, followed by an empty line. Labels are produced by an
     * {@link IobEncoder} over the {@code value} feature of {@link NamedEntity}.
     * <p>
     * Taken from Conll2003Writer and modified for the task at hand.
     */
    private void convert(JCas aJCas, PrintWriter aOut)
    {
        Type neType = JCasUtil.getType(aJCas, NamedEntity.class);
        Feature neValue = neType.getFeatureByBaseName("value");

        // Named Entities
        IobEncoder neEncoder = new IobEncoder(aJCas.getCas(), neType, neValue, false);

        Map<Sentence, Collection<NamedEntity>> idx = JCasUtil.indexCovered(aJCas, Sentence.class,
                NamedEntity.class);

        for (Sentence sentence : select(aJCas, Sentence.class)) {
            Collection<NamedEntity> coveredNEs = idx.get(sentence);
            // Don't include sentences without annotations in the temp file (saves memory for
            // training).
            if (coveredNEs == null || coveredNEs.isEmpty()) {
                continue;
            }

            // Write sentence in column format, tokens in document order
            for (Token token : selectCovered(Token.class, sentence)) {
                aOut.printf("%s\t%s%n", token.getCoveredText(), neEncoder.encode(token));
            }
            aOut.println();
        }
    }

    /**
     * Trains a CRF classifier on the collected data and serializes it to the target location.
     * The temporary training data is deleted afterwards, whether or not training succeeded.
     *
     * @throws AnalysisEngineProcessException
     *             if no training data was collected, the data or properties file cannot be
     *             read/written, or the model cannot be serialized.
     */
    @Override
    public void collectionProcessComplete()
        throws AnalysisEngineProcessException
    {
        try {
            File trainingData = finishTrainingData();

            // Load user-provided configuration
            Properties props = loadProperties();

            // Add/replace training file information. The Stanford classifier's train() checks
            // trainFiles and trainFileList before trainFile, so they must be removed or the data
            // converted by this component would be silently ignored.
            props.remove("trainFiles");
            props.remove("trainFileList");
            props.setProperty("serializeTo", targetLocation.getAbsolutePath());
            props.setProperty("trainFile", trainingData.getAbsolutePath());
            props.setProperty("map", "word=0,answer=1");

            SeqClassifierFlags flags = new SeqClassifierFlags(props);
            // label set
            flags.entitySubclassification = entitySubClassification;
            // if representation should be kept
            flags.retainEntitySubclassification = retainClassification;
            // need to use this reader because the other ones don't recognize the previous
            // settings about the label set
            flags.readerAndWriter = "edu.stanford.nlp.sequences.CoNLLDocumentReaderAndWriter";

            // Train
            CRFClassifier<CoreLabel> crf = new CRFClassifier<>(flags);
            getLogger().info("Starting to train...");
            crf.train();

            serialize(crf);
        }
        finally {
            // Reset state so a subsequent collection starts with fresh training data.
            releaseTrainingData();
        }
    }

    /**
     * Closes the training data writer and checks that usable training data exists.
     *
     * @return the temporary training data file.
     * @throws AnalysisEngineProcessException
     *             if nothing was processed, writing failed, or no sentence had annotations.
     */
    private File finishTrainingData()
        throws AnalysisEngineProcessException
    {
        if (tempData == null) {
            throw new AnalysisEngineProcessException(new IllegalStateException(
                    "No documents were processed - there is no data to train on."));
        }

        out.close();
        // PrintWriter never throws on write; checkError() reports whether any write failed.
        if (out.checkError()) {
            throw new AnalysisEngineProcessException(new IOException(
                    "Unable to write training data to " + tempData.getAbsolutePath()));
        }

        if (tempData.length() == 0) {
            throw new AnalysisEngineProcessException(new IllegalStateException(
                    "No sentence containing named entities was found - there is no data to "
                            + "train on."));
        }

        return tempData;
    }

    /**
     * Loads the user-provided training properties, or returns empty properties (meaning the
     * {@link SeqClassifierFlags} defaults apply) if no properties file is configured.
     */
    private Properties loadProperties()
        throws AnalysisEngineProcessException
    {
        Properties props = new Properties();
        if (propertiesFile == null) {
            getLogger().warn("No properties file configured - training with the default "
                    + "Stanford classifier settings.");
            return props;
        }

        try (InputStream is = new FileInputStream(propertiesFile)) {
            props.load(is);
        }
        catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
        return props;
    }

    /**
     * Serializes the trained classifier to the target location, creating its parent directory
     * if necessary.
     */
    private void serialize(CRFClassifier<CoreLabel> crf)
        throws AnalysisEngineProcessException
    {
        File parent = targetLocation.getAbsoluteFile().getParentFile();
        if (parent != null && !parent.isDirectory() && !parent.mkdirs()) {
            throw new AnalysisEngineProcessException(new IOException(
                    "Unable to create directory " + parent.getAbsolutePath()));
        }

        try {
            getLogger().info(String.format("Serializing classifier to target location: %s",
                    targetLocation.getCanonicalPath()));
            crf.serializeClassifier(targetLocation.getAbsolutePath());
        }
        catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * Closes the writer (if open) and deletes the temporary training data file (if any).
     */
    private void releaseTrainingData()
    {
        if (out != null) {
            out.close();
            out = null;
        }
        if (tempData != null) {
            if (!tempData.delete() && tempData.exists()) {
                getLogger().warn(String.format("Unable to delete temp file: %s",
                        tempData.getAbsolutePath()));
            }
            tempData = null;
        }
    }

    @Override
    public void destroy()
    {
        // Clean up temporary data file
        releaseTrainingData();
        super.destroy();
    }
}