/* * Copyright 2014 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * <p> * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * <p> * http://www.apache.org/licenses/LICENSE-2.0 * <p> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.dkpro.core.mallet.lda; import cc.mallet.topics.ParallelTopicModel; import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence; import de.tudarmstadt.ukp.dkpro.core.io.text.TextReader; import de.tudarmstadt.ukp.dkpro.core.testing.DkproTestContext; import de.tudarmstadt.ukp.dkpro.core.tokit.BreakIteratorSegmenter; import org.apache.uima.analysis_engine.AnalysisEngineDescription; import org.apache.uima.collection.CollectionReaderDescription; import org.apache.uima.fit.pipeline.SimplePipeline; import org.junit.Rule; import org.junit.Test; import java.io.File; import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription; import static org.apache.uima.fit.factory.CollectionReaderFactory.createReaderDescription; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; public class MalletLdaTopicModelTrainerTest { @Rule public DkproTestContext testContext = new DkproTestContext(); private static final String TXT_DIR = "src/test/resources/txt"; private static final String TXT_FILE_PATTERN = "[+]*.txt"; @Test public void testEstimator() throws Exception { File modelFile = new File(testContext.getTestOutputFolder(), "model"); // tag::example[] int nTopics = 10; int nIterations = 50; String language = "en"; CollectionReaderDescription reader = createReaderDescription(TextReader.class, TextReader.PARAM_SOURCE_LOCATION, TXT_DIR, TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN, TextReader.PARAM_LANGUAGE, language); AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class); AnalysisEngineDescription estimator = createEngineDescription( MalletLdaTopicModelTrainer.class, MalletLdaTopicModelTrainer.PARAM_TARGET_LOCATION, modelFile, MalletLdaTopicModelTrainer.PARAM_N_ITERATIONS, nIterations, MalletLdaTopicModelTrainer.PARAM_N_TOPICS, nTopics); SimplePipeline.runPipeline(reader, segmenter, estimator); // end::example[] assertTrue(modelFile.exists()); ParallelTopicModel model = ParallelTopicModel.read(modelFile); assertEquals(nTopics, model.getNumTopics()); } @Test public void testEstimatorSentence() throws Exception { File modelFile = new File(testContext.getTestOutputFolder(), "model"); int nTopics = 10; int nIterations = 50; String language = "en"; String entity = Sentence.class.getName(); CollectionReaderDescription reader = createReaderDescription(TextReader.class, TextReader.PARAM_SOURCE_LOCATION, TXT_DIR, TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN, TextReader.PARAM_LANGUAGE, language); AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class); AnalysisEngineDescription estimator = createEngineDescription( MalletLdaTopicModelTrainer.class, MalletLdaTopicModelTrainer.PARAM_TARGET_LOCATION, modelFile, MalletLdaTopicModelTrainer.PARAM_N_ITERATIONS, nIterations, MalletLdaTopicModelTrainer.PARAM_N_TOPICS, nTopics, MalletLdaTopicModelTrainer.PARAM_COVERING_ANNOTATION_TYPE, entity); SimplePipeline.runPipeline(reader, segmenter, estimator); assertTrue(modelFile.exists()); ParallelTopicModel model = ParallelTopicModel.read(modelFile); assertEquals(nTopics, model.getNumTopics()); } @Test public void testEstimatorAlphaBeta() throws Exception { File modelFile = new File(testContext.getTestOutputFolder(), "model"); int nTopics = 10; int nIterations = 50; float alpha = nTopics / 50.0f; float beta = 0.01f; String language = "en"; CollectionReaderDescription reader = createReaderDescription(TextReader.class, TextReader.PARAM_SOURCE_LOCATION, TXT_DIR, TextReader.PARAM_PATTERNS, TXT_FILE_PATTERN, TextReader.PARAM_LANGUAGE, language); AnalysisEngineDescription segmenter = createEngineDescription(BreakIteratorSegmenter.class); AnalysisEngineDescription estimator = createEngineDescription( MalletLdaTopicModelTrainer.class, MalletLdaTopicModelTrainer.PARAM_TARGET_LOCATION, modelFile, MalletLdaTopicModelTrainer.PARAM_N_ITERATIONS, nIterations, MalletLdaTopicModelTrainer.PARAM_N_TOPICS, nTopics, MalletLdaTopicModelTrainer.PARAM_ALPHA_SUM, alpha, MalletLdaTopicModelTrainer.PARAM_BETA, beta); SimplePipeline.runPipeline(reader, segmenter, estimator); assertTrue(modelFile.exists()); ParallelTopicModel model = ParallelTopicModel.read(modelFile); assertEquals(nTopics, model.getNumTopics()); } }