/**
* Copyright 2007-2014
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package de.tudarmstadt.ukp.dkpro.core.stanfordnlp;
import static org.apache.uima.fit.util.JCasUtil.indexCovered;
import static org.apache.uima.fit.util.JCasUtil.select;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Map;
import java.util.Properties;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.component.JCasConsumer_ImplBase;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.jcas.JCas;
import org.apache.uima.resource.ResourceInitializationException;
import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters;
import de.tudarmstadt.ukp.dkpro.core.api.resources.ResourceUtils;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import edu.stanford.nlp.tagger.maxent.MaxentTagger;
/**
 * Train a POS tagging model for the Stanford POS tagger.
 * <p>
 * For every processed document, the tokens of each sentence are appended to a temporary UTF-8
 * tab-separated file (one line per token containing its covered text and POS value, followed by
 * a blank line after each sentence). When the collection processing is complete, the Stanford
 * {@link MaxentTagger} is trained from this file using the user-provided parameter file, and the
 * resulting model is written to {@link #PARAM_TARGET_LOCATION}.
 */
public class StanfordPosTaggerTrainer
    extends JCasConsumer_ImplBase
{
    /**
     * Location to which the trained model is written. Passed to the trainer as the
     * <code>model</code> property.
     */
    public static final String PARAM_TARGET_LOCATION = ComponentParameters.PARAM_TARGET_LOCATION;
    @ConfigurationParameter(name = PARAM_TARGET_LOCATION, mandatory = true)
    private File targetLocation;

    /**
     * Training file containing the parameters. The <code>trainFile</code>, <code>model</code> and
     * <code>encoding</code> parameters in this file are ignored/overwritten. In the <code>arch</code>
     * parameter, the string <code>${distsimCluster}</code> is replaced with the path to the cluster
     * files if {@link #PARAM_CLUSTER_FILE} is specified.
     */
    public static final String PARAM_PARAMETER_FILE = "trainFile";
    @ConfigurationParameter(name = PARAM_PARAMETER_FILE, mandatory = true)
    private File parameterFile;

    /**
     * Distsim cluster files.
     */
    public static final String PARAM_CLUSTER_FILE = "clusterFile";
    @ConfigurationParameter(name = PARAM_CLUSTER_FILE, mandatory = false)
    private File clusterFile;

    /** Whether {@link #clusterFile} points to a temporary copy that must be deleted on destroy. */
    private boolean clusterFilesTemporary;

    /** Temporary TSV file collecting the training data; created on the first document. */
    private File tempData;

    /** Writer for {@link #tempData}; non-null whenever {@link #tempData} is non-null. */
    private PrintWriter out;

    @Override
    public void initialize(UimaContext aContext)
        throws ResourceInitializationException
    {
        super.initialize(aContext);

        clusterFilesTemporary = false;

        // The cluster file is optional - nothing to prepare if it has not been specified.
        if (clusterFile == null) {
            return;
        }

        String p = clusterFile.getAbsolutePath();
        if (p.contains("(") || p.contains(")") || p.contains(",")) {
            // The Stanford POS tagger trainer does not support these characters in the cluster
            // files path. If we have those, copy the clusters somewhere safe before training.
            // See: https://github.com/stanfordnlp/CoreNLP/issues/255
            // The (sanitized) original file name is kept as suffix so the file extension is
            // preserved in case the Stanford reader relies on it (e.g. for compressed files).
            String suffix = "-" + clusterFile.getName().replaceAll("[(),]", "_");
            File tempClusterFile = null;
            try {
                tempClusterFile = File.createTempFile("dkpro-stanford-pos-trainer", suffix);
                FileUtils.copyFile(clusterFile, tempClusterFile);
            }
            catch (IOException e) {
                // Do not leave a half-written copy behind.
                if (tempClusterFile != null) {
                    tempClusterFile.delete();
                }
                throw new ResourceInitializationException(e);
            }
            clusterFile = tempClusterFile;
            clusterFilesTemporary = true;
        }
    }

    @Override
    public void process(JCas aJCas)
        throws AnalysisEngineProcessException
    {
        // Open the temporary training data file when the first document arrives.
        if (tempData == null) {
            File file = null;
            try {
                file = File.createTempFile("dkpro-stanford-pos-trainer", ".tsv");
                out = new PrintWriter(new OutputStreamWriter(new FileOutputStream(file),
                        StandardCharsets.UTF_8));
            }
            catch (IOException e) {
                if (file != null) {
                    file.delete();
                }
                throw new AnalysisEngineProcessException(e);
            }
            // Only remember the file once the writer is open, so tempData and out stay in sync.
            tempData = file;
        }

        Map<Sentence, Collection<Token>> index = indexCovered(aJCas, Sentence.class, Token.class);
        for (Sentence sentence : select(aJCas, Sentence.class)) {
            Collection<Token> tokens = index.get(sentence);
            for (Token token : tokens) {
                String tag = token.getPos() != null ? token.getPos().getPosValue() : null;
                if (tag == null) {
                    // Training data without a tag is useless - fail with a helpful message
                    // instead of a NullPointerException.
                    throw new AnalysisEngineProcessException(new IllegalStateException(
                            "Token [" + token.getCoveredText() + "] at [" + token.getBegin()
                                    + "-" + token.getEnd() + "] has no POS tag"));
                }
                out.printf("%s\t%s%n", token.getCoveredText(), tag);
            }
            // A blank line terminates the sentence in the TSV training format.
            out.println();
        }
    }

    @Override
    public void collectionProcessComplete()
        throws AnalysisEngineProcessException
    {
        if (tempData == null) {
            throw new AnalysisEngineProcessException(new IllegalStateException(
                    "No training data has been collected - no documents were processed"));
        }

        // Flush and close the training data. PrintWriter never throws I/O errors, so they must
        // be queried explicitly; otherwise the model would silently be trained on truncated data.
        out.close();
        if (out.checkError()) {
            throw new AnalysisEngineProcessException(
                    new IOException("Unable to write training data to [" + tempData + "]"));
        }

        // Load user-provided configuration
        Properties props = new Properties();
        try (InputStream is = new FileInputStream(parameterFile)) {
            props.load(is);
        }
        catch (IOException e) {
            throw new AnalysisEngineProcessException(e);
        }

        // Add/replace training file information
        props.setProperty("trainFile",
                "format=TSV,wordColumn=0,tagColumn=1," + tempData.getAbsolutePath());
        props.setProperty("model", targetLocation.getAbsolutePath());
        props.setProperty("encoding", "UTF-8");

        if (clusterFile != null) {
            String arch = props.getProperty("arch");
            if (arch != null) {
                // Literal (non-regex) replacement: the path may contain '\' (Windows) or '$',
                // which replaceAll would interpret as escapes/group references.
                arch = arch.replace("${distsimCluster}", clusterFile.getAbsolutePath());
                props.setProperty("arch", arch);
            }
        }

        File tempConfig = null;
        try {
            // Write to a temporary location
            tempConfig = File.createTempFile("dkpro-stanford-pos-trainer", ".props");
            try (OutputStream os = new FileOutputStream(tempConfig)) {
                props.store(os, null);
            }

            // Train
            MaxentTagger.main(new String[] { "-props", tempConfig.getAbsolutePath() });
        }
        catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
        finally {
            // Clean up temporary parameter file
            if (tempConfig != null) {
                tempConfig.delete();
            }
        }
    }

    @Override
    public void destroy()
    {
        super.destroy();

        // Release the file handle (e.g. if collectionProcessComplete was never reached) before
        // deleting the temporary file; closing an already closed writer is a no-op.
        IOUtils.closeQuietly(out);

        // Clean up temporary data file
        if (tempData != null) {
            tempData.delete();
        }

        // Only delete the cluster file if it is our own temporary copy.
        if (clusterFilesTemporary) {
            clusterFile.delete();
        }
    }
}