/* * Copyright 2014 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * <p> * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * <p> * http://www.apache.org/licenses/LICENSE-2.0 * <p> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.dkpro.core.mallet.lda; import cc.mallet.pipe.Pipe; import cc.mallet.pipe.TokenSequence2FeatureSequence; import cc.mallet.topics.ParallelTopicModel; import cc.mallet.topics.TopicInferencer; import cc.mallet.types.Instance; import cc.mallet.types.TokenSequence; import de.tudarmstadt.ukp.dkpro.core.api.featurepath.FeaturePathException; import de.tudarmstadt.ukp.dkpro.core.api.io.sequencegenerator.PhraseSequenceGenerator; import de.tudarmstadt.ukp.dkpro.core.api.io.sequencegenerator.StringSequenceGenerator; import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData; import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters; import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token; import de.tudarmstadt.ukp.dkpro.core.mallet.MalletModelTrainer; import de.tudarmstadt.ukp.dkpro.core.mallet.type.TopicDistribution; import org.apache.commons.lang3.ArrayUtils; import org.apache.uima.UimaContext; import org.apache.uima.analysis_engine.AnalysisEngineProcessException; import org.apache.uima.fit.component.JCasAnnotator_ImplBase; import org.apache.uima.fit.descriptor.ConfigurationParameter; import org.apache.uima.fit.descriptor.TypeCapability; import org.apache.uima.jcas.JCas; import org.apache.uima.jcas.cas.DoubleArray; import org.apache.uima.jcas.cas.IntegerArray; import org.apache.uima.resource.ResourceInitializationException; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; /** * Infers the topic distribution over documents using a Mallet {@link ParallelTopicModel}. */ @TypeCapability( inputs = { "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token" }, outputs = { "de.tudarmstadt.ukp.dkpro.core.mallet.type.TopicDistribution" } ) public class MalletLdaTopicModelInferencer extends JCasAnnotator_ImplBase { private static final String NONE_LABEL = "X"; public final static String PARAM_MODEL_LOCATION = ComponentParameters.PARAM_MODEL_LOCATION; @ConfigurationParameter(name = PARAM_MODEL_LOCATION, mandatory = true) private File modelLocation; /** * The annotation type to use as tokens. Default: {@link Token} */ public final static String PARAM_TYPE_NAME = "typeName"; @ConfigurationParameter(name = PARAM_TYPE_NAME, mandatory = true, defaultValue = "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token") private String typeName; /** * The number of iterations during inference. Default: 100. */ public final static String PARAM_N_ITERATIONS = "nIterations"; @ConfigurationParameter(name = PARAM_N_ITERATIONS, mandatory = true, defaultValue = "100") private int nIterations; /** * The number of iterations before hyperparameter optimization begins. Default: 1 */ public final static String PARAM_BURN_IN = "burnIn"; @ConfigurationParameter(name = PARAM_BURN_IN, mandatory = true, defaultValue = "1") private int burnIn; public final static String PARAM_THINNING = "thinning"; @ConfigurationParameter(name = PARAM_THINNING, mandatory = true, defaultValue = "5") private int thinning; /** * Minimum topic proportion for the document-topic assignment. */ public final static String PARAM_MIN_TOPIC_PROB = "minTopicProb"; @ConfigurationParameter(name = PARAM_MIN_TOPIC_PROB, mandatory = true, defaultValue = "0.2") private double minTopicProb; /** * Maximum number of topics to assign. If not set (or <= 0), the number of topics in the * model divided by 10 is set. */ public final static String PARAM_MAX_TOPIC_ASSIGNMENTS = "maxTopicAssignments"; @ConfigurationParameter(name = PARAM_MAX_TOPIC_ASSIGNMENTS, mandatory = true, defaultValue = "0") private int maxTopicAssignments; /** * The annotation type to use for the model. Default: {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token}. * For lemmas, use {@code de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token/lemma/value} */ public static final String PARAM_TOKEN_FEATURE_PATH = MalletModelTrainer.PARAM_TOKEN_FEATURE_PATH; @ConfigurationParameter(name = PARAM_TOKEN_FEATURE_PATH, mandatory = true, defaultValue = "de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token") private String tokenFeaturePath; /** * Ignore tokens (or lemmas, respectively) that are shorter than the given value. Default: 3. */ public static final String PARAM_MIN_TOKEN_LENGTH = "minTokenLength"; @ConfigurationParameter(name = PARAM_MIN_TOKEN_LENGTH, mandatory = true, defaultValue = "3") private int minTokenLength; /** * If set to true (default: false), all tokens are lowercased. */ public static final String PARAM_LOWERCASE = "lowercase"; @ConfigurationParameter(name = PARAM_LOWERCASE, mandatory = true, defaultValue = "false") private boolean lowercase; private TopicInferencer inferencer; private Pipe malletPipe; private StringSequenceGenerator sequenceGenerator; @Override public void initialize(UimaContext context) throws ResourceInitializationException { super.initialize(context); ParallelTopicModel model; try { getLogger().info("Loading model file " + modelLocation); model = ParallelTopicModel.read(modelLocation); if (maxTopicAssignments <= 0) { maxTopicAssignments = model.getNumTopics() / 10; } } catch (Exception e) { throw new ResourceInitializationException(e); } getLogger().info("Model loaded."); inferencer = model.getInferencer(); malletPipe = new TokenSequence2FeatureSequence(model.getAlphabet()); try { sequenceGenerator = new PhraseSequenceGenerator.Builder() .featurePath(tokenFeaturePath) .minTokenLength(minTokenLength) .lowercase(lowercase) .buildStringSequenceGenerator(); } catch (IOException e) { throw new ResourceInitializationException(e); } } @Override public void process(JCas aJCas) throws AnalysisEngineProcessException { try { List<String[]> tokenSequences = sequenceGenerator.tokenSequences(aJCas); if (tokenSequences.isEmpty()) { getLogger().warn("Empty document."); } else { DocumentMetaData metadata = DocumentMetaData.get(aJCas); /* create Mallet Instance */ TokenSequence ts = new TokenSequence(tokenSequences.get(0)); Instance instance = new Instance(ts, NONE_LABEL, metadata.getDocumentId(), metadata.getDocumentUri()); /* infer topic distribution across document */ TopicDistribution topicDistributionAnnotation = new TopicDistribution(aJCas); double[] topicDistribution = inferencer.getSampledDistribution( malletPipe.instanceFrom(instance), nIterations, thinning, burnIn); /* convert data type (Mallet output -> UIMA double array) */ DoubleArray da = new DoubleArray(aJCas, topicDistribution.length); da.copyFromArray(topicDistribution, 0, 0, topicDistribution.length); topicDistributionAnnotation.setTopicProportions(da); /* assign topics to document according to topic distribution */ int[] assignedTopicIndexes = assignTopics(topicDistribution); IntegerArray topicIndexes = new IntegerArray(aJCas, assignedTopicIndexes.length); topicIndexes.copyFromArray(assignedTopicIndexes, 0, 0, assignedTopicIndexes.length); topicDistributionAnnotation.setTopicAssignment(topicIndexes); aJCas.addFsToIndexes(topicDistributionAnnotation); } } catch (FeaturePathException e) { throw new AnalysisEngineProcessException(e); } } /** * Assign topics according to the following formula: * <p> * Topic proportion must be at least the maximum topic's proportion divided by the maximum * number of topics to be assigned. In addition, the topic proportion must not lie under the * minTopicProb. If more topics comply with these criteria, only retain the n * (maxTopicAssignments) largest values. * * @param topicDistribution a double array containing the document's topic proportions * @return an array of integers pointing to the topics assigned to the document * @deprecated this method should be removed at some point because assignment / topic tagging * should be done in a dedicated step (module). */ // TODO: should return a boolean[] of the same size as topicDistribution // TODO: should probably be moved to a dedicated module because assignments (topic tagging) // should not be done at inference level @Deprecated private int[] assignTopics(final double[] topicDistribution) { /* * threshold is the largest value divided by the maximum number of topics or the fixed * number set as minTopicProb parameter. */ double threshold = Math.max(Collections.max( Arrays.asList(ArrayUtils.toObject(topicDistribution))).doubleValue() / maxTopicAssignments, minTopicProb); /* * assign indexes for values that are above threshold */ List<Integer> indexes = new ArrayList<>(topicDistribution.length); for (int i = 0; i < topicDistribution.length; i++) { if (topicDistribution[i] >= threshold) { indexes.add(i); } } /* * Reduce assignments to maximum number of allowed assignments. */ if (indexes.size() > maxTopicAssignments) { /* sort index list by corresponding values */ Collections.sort(indexes, (aO1, aO2) -> Double.compare(topicDistribution[aO1], topicDistribution[aO2])); while (indexes.size() > maxTopicAssignments) { indexes.remove(0); } } return ArrayUtils.toPrimitive(indexes.toArray(new Integer[indexes.size()])); } }