/* * Copyright 2015 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.dkpro.core.mallet.lda; import cc.mallet.topics.ParallelTopicModel; import cc.mallet.types.Alphabet; import cc.mallet.types.IDSorter; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import java.io.File; import java.io.IOException; import java.util.*; public class MalletLdaTopicModelUtils { private static final Log LOGGER = LogFactory.getLog(MalletLdaTopicModelUtils.class); /** * Retrieve the top n topic words for each topic in the given model. * * @param modelFile * the model file * @param nWords * the maximum number of words to retrieve * @param normalize * normalize the word weights ? * * @return a list of maps where each map represents a topic, mapping words to weights * @throws IOException * if the model cannot be read */ public static List<Map<String, Double>> getTopWords(File modelFile, int nWords, boolean normalize) throws IOException { LOGGER.info("Reading model file " + modelFile + "..."); ParallelTopicModel model; try { model = ParallelTopicModel.read(modelFile); } catch (Exception e) { throw new IOException(e); } Alphabet alphabet = model.getAlphabet(); List<Map<String, Double>> topics = new ArrayList<>(model.getNumTopics()); /* iterate over topics */ for (TreeSet<IDSorter> topic : model.getSortedWords()) { Map<String, Double> topicWords = new HashMap<>(nWords); /* iterate over word IDs in topic (sorted by weight) */ for (IDSorter id : topic) { double weight = normalize ? id.getWeight() / alphabet.size() : id.getWeight(); // normalize String word = (String) alphabet.lookupObject(id.getID()); topicWords.put(word, weight); if (topicWords.size() >= nWords) { break; // go to next topic } } topics.add(topicWords); } return topics; } /** * Print the top n words of each topic into a file. * * @param modelFile * the model file * @param targetFile * the file in which the topic words are written * @param nWords * the number of words to extract * @throws IOException * if the model file cannot be read or if the target file cannot be written */ public static void printTopicWords(File modelFile, File targetFile, int nWords) throws IOException { boolean newLineAfterEachWord = false; ParallelTopicModel model; try { model = ParallelTopicModel.read(modelFile); } catch (Exception e) { throw new IOException(e); } model.printTopWords(targetFile, nWords, newLineAfterEachWord); } }