/*
 * Copyright 2014
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.mallet.lda;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Instance;
import de.tudarmstadt.ukp.dkpro.core.mallet.MalletModelTrainer;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.descriptor.ConfigurationParameter;

import java.io.File;
import java.io.IOException;

/**
 * Estimate an LDA topic model using Mallet and write it to a file. All incoming CASes are stored
 * as Mallet {@link Instance}s before the model is estimated with a {@link ParallelTopicModel}.
 * <p>
 * Set {@link #PARAM_TOKEN_FEATURE_PATH} to define what is considered as a token (Tokens, Lemmas, etc.).
 * <p>
 * Set {@link #PARAM_COVERING_ANNOTATION_TYPE} to define what is considered a document (sentences, paragraphs, etc.).
 */
public class MalletLdaTopicModelTrainer
        extends MalletModelTrainer
{
    /**
     * The number of topics to estimate. Default: 10.
     */
    public static final String PARAM_N_TOPICS = "nTopics";
    @ConfigurationParameter(name = PARAM_N_TOPICS, mandatory = true, defaultValue = "10")
    private int nTopics;

    /**
     * The number of iterations during model estimation. Default: 1000.
     */
    public static final String PARAM_N_ITERATIONS = "nIterations";
    @ConfigurationParameter(name = PARAM_N_ITERATIONS, mandatory = true, defaultValue = "1000")
    private int nIterations;

    /**
     * The number of iterations before hyper-parameter optimization begins. Default: 100
     */
    public static final String PARAM_BURNIN_PERIOD = "burninPeriod";
    @ConfigurationParameter(name = PARAM_BURNIN_PERIOD, mandatory = true, defaultValue = "100")
    private int burninPeriod;

    /**
     * Interval for optimizing Dirichlet hyper-parameters. Default: 50
     */
    public static final String PARAM_OPTIMIZE_INTERVAL = "optimizeInterval";
    @ConfigurationParameter(name = PARAM_OPTIMIZE_INTERVAL, mandatory = true, defaultValue = "50")
    private int optimizeInterval;

    /**
     * Set random seed. If set to -1 (default), uses random generator.
     */
    public static final String PARAM_RANDOM_SEED = "randomSeed";
    @ConfigurationParameter(name = PARAM_RANDOM_SEED, mandatory = true, defaultValue = "-1")
    private int randomSeed;

    /**
     * Define how frequently a serialized model is saved to disk during estimation. Default: 0 (only save when
     * estimation is done).
     */
    public static final String PARAM_SAVE_INTERVAL = "saveInterval";
    @ConfigurationParameter(name = PARAM_SAVE_INTERVAL, mandatory = true, defaultValue = "0")
    private int saveInterval;

    /**
     * Use a symmetric alpha value during model estimation? Default: false.
     */
    public static final String PARAM_USE_SYMMETRIC_ALPHA = "useSymmetricAlpha";
    @ConfigurationParameter(name = PARAM_USE_SYMMETRIC_ALPHA, mandatory = true, defaultValue = "false")
    private boolean useSymmetricAlpha;

    /**
     * The interval in which to display the estimated topics. Default: 50.
     */
    public static final String PARAM_DISPLAY_INTERVAL = "displayInterval";
    @ConfigurationParameter(name = PARAM_DISPLAY_INTERVAL, mandatory = true, defaultValue = "50")
    private int displayInterval;

    /**
     * The number of top words to display during estimation. Default: 7.
     */
    public static final String PARAM_DISPLAY_N_TOPIC_WORDS = "displayNTopicWords";
    @ConfigurationParameter(name = PARAM_DISPLAY_N_TOPIC_WORDS, mandatory = true, defaultValue = "7")
    private int displayNTopicWords;

    /**
     * The sum of alphas over all topics. Default: 1.0.
     * <p>
     * Another recommended value is 50 / T (number of topics).
     */
    public static final String PARAM_ALPHA_SUM = "alphaSum";
    @ConfigurationParameter(name = PARAM_ALPHA_SUM, mandatory = true, defaultValue = "1.0f")
    private float alphaSum;

    /**
     * Beta for a single dimension of the Dirichlet prior. Default: 0.01.
     */
    public static final String PARAM_BETA = "beta";
    @ConfigurationParameter(name = PARAM_BETA, mandatory = true, defaultValue = "0.01f")
    private float beta;

    /**
     * Configures a {@link ParallelTopicModel} from the component parameters, estimates it on the
     * collected instances and writes the result to the target location.
     * <p>
     * The directory of the target file is created <em>before</em> estimation starts: the target
     * location is also handed to the model as the file name for serialized models saved during
     * estimation (see {@link #PARAM_SAVE_INTERVAL}), and an unusable output directory should be
     * reported before the estimation is run rather than after it.
     *
     * @throws AnalysisEngineProcessException
     *             if the output directory cannot be created, or if estimation fails with an
     *             {@link IOException} or {@link SecurityException}.
     */
    @Override
    public void collectionProcessComplete()
            throws AnalysisEngineProcessException
    {
        try {
            File targetFile = new File(getTargetLocation());
            ensureParentDirectoryExists(targetFile);

            ParallelTopicModel model = new ParallelTopicModel(nTopics, alphaSum, beta);
            model.addInstances(getInstanceList());
            model.setNumThreads(getNumThreads());
            model.setNumIterations(nIterations);
            model.setBurninPeriod(burninPeriod);
            model.setOptimizeInterval(optimizeInterval);
            model.setRandomSeed(randomSeed);
            model.setSaveSerializedModel(saveInterval, getTargetLocation());
            model.setSymmetricAlpha(useSymmetricAlpha);
            model.setTopicDisplay(displayInterval, displayNTopicWords);
            model.estimate();

            getLogger().info("Writing model to " + getTargetLocation());
            model.write(targetFile);
        }
        catch (IOException | SecurityException e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * Creates the parent directory of the given file (including missing ancestors) if the file
     * has a parent and that directory does not exist yet.
     *
     * @param targetFile
     *            the file whose parent directory is required.
     * @throws IOException
     *             if the directory does not exist and could not be created.
     */
    private static void ensureParentDirectoryExists(File targetFile)
            throws IOException
    {
        File parent = targetFile.getParentFile();
        if (parent == null || parent.isDirectory()) {
            return;
        }
        // mkdirs() returns false both on failure and when another process created the directory
        // in the meantime, so re-check before reporting an error.
        if (!parent.mkdirs() && !parent.isDirectory()) {
            throw new IOException("Unable to create output directory " + parent);
        }
    }
}