/* * Copyright 2016 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * <p> * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * <p> * http://www.apache.org/licenses/LICENSE-2.0 * <p> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.dkpro.core.mallet; import cc.mallet.pipe.TokenSequence2FeatureSequence; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.types.TokenSequence; import de.tudarmstadt.ukp.dkpro.core.api.featurepath.FeaturePathException; import de.tudarmstadt.ukp.dkpro.core.api.io.JCasFileWriter_ImplBase; import de.tudarmstadt.ukp.dkpro.core.api.io.sequencegenerator.PhraseSequenceGenerator; import de.tudarmstadt.ukp.dkpro.core.api.io.sequencegenerator.StringSequenceGenerator; import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData; import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters; import de.tudarmstadt.ukp.dkpro.core.mallet.lda.MalletLdaTopicModelTrainer; import de.tudarmstadt.ukp.dkpro.core.mallet.wordembeddings.MalletEmbeddingsTrainer; import org.apache.uima.UimaContext; import org.apache.uima.analysis_engine.AnalysisEngineProcessException; import org.apache.uima.fit.descriptor.ConfigurationParameter; import org.apache.uima.jcas.JCas; import org.apache.uima.resource.ResourceInitializationException; import java.io.IOException; import java.util.Locale; /** * This abstract class defines parameters and methods that are common for Mallet model estimators. * <p> * It creates a Mallet {@link InstanceList} from the input documents so that inheriting estimators * can create a model, typically implemented by overriding the {@link JCasFileWriter_ImplBase#collectionProcessComplete()} * method. * * @see MalletEmbeddingsTrainer * @see MalletLdaTopicModelTrainer * @since 1.9.0 */ public abstract class MalletModelTrainer extends JCasFileWriter_ImplBase { private static final String NONE_LABEL = "X"; // some label has to be set for Mallet instances private static final Locale LOCALE = Locale.US; /** * The annotation type to use as input tokens for the model estimation. * Default: {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token}. * For lemmas, for instance, use {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token/lemma/value} */ public static final String PARAM_TOKEN_FEATURE_PATH = "tokenFeaturePath"; @ConfigurationParameter(name = PARAM_TOKEN_FEATURE_PATH, mandatory = true, defaultValue = "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token") private String tokenFeaturePath; /** * The number of threads to use during model estimation. * If not set, the number of threads is automatically set by {@link ComponentParameters#computeNumThreads(int)}. * <p> * Warning: do not set this to more than 1 when using very small (test) data sets on {@link MalletEmbeddingsTrainer}! * This might prevent the process from terminating. */ public static final String PARAM_NUM_THREADS = ComponentParameters.PARAM_NUM_THREADS; @ConfigurationParameter(name = PARAM_NUM_THREADS, mandatory = true, defaultValue = ComponentParameters.AUTO_NUM_THREADS) private int numThreads; /** * Ignore tokens (or any other annotation type, as specified by {@link #PARAM_TOKEN_FEATURE_PATH}) * that are shorter than the given value. Default: 3. */ public static final String PARAM_MIN_TOKEN_LENGTH = "minTokenLength"; @ConfigurationParameter(name = PARAM_MIN_TOKEN_LENGTH, mandatory = true, defaultValue = "3") private int minTokenLength; /** * If specified, the text contained in the given segmentation type annotations are fed as * separate units ("documents") to the topic model estimator e.g. * {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.sentence}. Text that is not within * such annotations is ignored. * <p> * By default, the full text is used as a document. */ public static final String PARAM_COVERING_ANNOTATION_TYPE = "coveringAnnotationType"; @ConfigurationParameter(name = PARAM_COVERING_ANNOTATION_TYPE, mandatory = true, defaultValue = "") private String coveringAnnotationType; /** * If true (default: false), estimate character embeddings. {@link #PARAM_TOKEN_FEATURE_PATH} is * ignored. */ public static final String PARAM_USE_CHARACTERS = "useCharacters"; @ConfigurationParameter(name = PARAM_USE_CHARACTERS, mandatory = true, defaultValue = "false") private boolean useCharacters; /** * If set to true (default: false), all tokens are lowercased. */ public static final String PARAM_LOWERCASE = "lowercase"; @ConfigurationParameter(name = PARAM_LOWERCASE, mandatory = true, defaultValue = "false") private boolean lowercase; /** * The location of the stopwords file. */ public static final String PARAM_STOPWORDS_FILE = "paramStopwordsFile"; @ConfigurationParameter(name = PARAM_STOPWORDS_FILE, mandatory = true, defaultValue = "") private String stopwordsFile; /** * If set, stopwords found in the {@link #PARAM_STOPWORDS_FILE} location are not removed, but * replaced by the given string (e.g. {@code STOP}). */ public static final String PARAM_STOPWORDS_REPLACEMENT = "paramStopwordsReplacement"; @ConfigurationParameter(name = PARAM_STOPWORDS_REPLACEMENT, mandatory = true, defaultValue = "") private String stopwordsReplacement; /** * Filter out all tokens matching that regular expression. */ public static final String PARAM_FILTER_REGEX = "filterRegex"; @ConfigurationParameter(name = PARAM_FILTER_REGEX, mandatory = true, defaultValue = "") private String filterRegex; public static final String PARAM_FILTER_REGEX_REPLACEMENT = "filterRegexReplacement"; @ConfigurationParameter(name = PARAM_FILTER_REGEX_REPLACEMENT, mandatory = true, defaultValue = "") private String filterRegexReplacement; private InstanceList instanceList; // contains the Mallet instances private StringSequenceGenerator sequenceGenerator; @Override public void initialize(UimaContext context) throws ResourceInitializationException { super.initialize(context); if (getTargetLocation() == null) { throw new ResourceInitializationException( new IllegalArgumentException("No target location set!")); } // locale should be set to US to define the output format of the Mallet models (especially decimal numbers). Locale.setDefault(LOCALE); numThreads = ComponentParameters.computeNumThreads(numThreads); getLogger().info(String.format("Using %d threads.", numThreads)); /* Mallet instance list and token sequence generator */ instanceList = new InstanceList(new TokenSequence2FeatureSequence()); try { sequenceGenerator = new PhraseSequenceGenerator.Builder() .characters(useCharacters) .minTokenLength(minTokenLength) .stopwordsFile(stopwordsFile) .stopwordsReplacement(stopwordsReplacement) .featurePath(tokenFeaturePath) .filterRegex(filterRegex) .filterRegexReplacement(filterRegexReplacement) .coveringType(coveringAnnotationType) .lowercase(lowercase) .buildStringSequenceGenerator(); } catch (IOException e) { throw new ResourceInitializationException(e); } } @Override public void process(JCas aJCas) throws AnalysisEngineProcessException { DocumentMetaData metadata = DocumentMetaData.get(aJCas); try { /* retrieve token sequences and convert token sequences to instances */ sequenceGenerator.tokenSequences(aJCas).stream() .map(TokenSequence::new) .map(ts -> new Instance(ts, NONE_LABEL, metadata.getDocumentId(), metadata.getDocumentUri())) .forEach(instance -> instanceList.addThruPipe(instance)); } catch (FeaturePathException e) { throw new AnalysisEngineProcessException(e); } } protected int getNumThreads() { return numThreads; } public InstanceList getInstanceList() { return instanceList; } }