/* * Copyright 2014 * Ubiquitous Knowledge Processing (UKP) Lab * Technische Universität Darmstadt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package de.tudarmstadt.ukp.dkpro.core.io.ditop; import static org.apache.uima.fit.util.JCasUtil.select; import java.io.BufferedWriter; import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.TreeSet; import org.apache.commons.collections4.Bag; import org.apache.commons.collections4.bag.HashBag; import org.apache.commons.io.FileUtils; import org.apache.uima.UimaContext; import org.apache.uima.analysis_engine.AnalysisEngineProcessException; import org.apache.uima.fit.descriptor.ConfigurationParameter; import org.apache.uima.fit.descriptor.MimeTypeCapability; import org.apache.uima.fit.descriptor.TypeCapability; import org.apache.uima.jcas.JCas; import org.apache.uima.jcas.cas.DoubleArray; import org.apache.uima.resource.ResourceInitializationException; import cc.mallet.topics.ParallelTopicModel; import cc.mallet.types.Alphabet; import cc.mallet.types.IDSorter; import de.tudarmstadt.ukp.dkpro.core.api.io.JCasFileWriter_ImplBase; import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData; import de.tudarmstadt.ukp.dkpro.core.api.parameter.ComponentParameters; import de.tudarmstadt.ukp.dkpro.core.api.parameter.MimeTypes; import de.tudarmstadt.ukp.dkpro.core.mallet.type.TopicDistribution; /** * This annotator (consumer) writes output files as required by <a * href="https://ditop.hs8.de/">DiTop</a>. It requires JCas input annotated by * {@link de.tudarmstadt.ukp.dkpro.core.mallet.lda.MalletLdaTopicModelInferencer} using the same model. */ @MimeTypeCapability({MimeTypes.APPLICATION_X_DITOP}) @TypeCapability( inputs = { "de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData", "de.tudarmstadt.ukp.dkpro.core.mallet.type.TopicDistribution" }) public class DiTopWriter extends JCasFileWriter_ImplBase { private static final String FIELDSEPARATOR_CONFIGFILE = ";"; private final static String DOC_TOPICS_FILE = "topics.csv"; private final static String TOPIC_TERM_FILE = "topicTerm.txt"; private final static String TOPIC_TERM_MATRIX_FILE = "topicTermMatrix.txt"; private final static String TOPIC_SUMMARY_FILE = "topicTerm-T15.txt"; private final static String CONFIG_FILE = "config.all"; /** * The maximum number of topic words to extract. Default: 15 */ public static final String PARAM_MAX_TOPIC_WORDS = "maxTopicWords"; @ConfigurationParameter(name = PARAM_MAX_TOPIC_WORDS, mandatory = true, defaultValue = "15") private int maxTopicWords; /** * A Mallet file storing a serialized {@link ParallelTopicModel}. */ public static final String PARAM_MODEL_LOCATION = ComponentParameters.PARAM_MODEL_LOCATION; @ConfigurationParameter(name = PARAM_MODEL_LOCATION, mandatory = true) protected File modelLocation; /** * The corpus name is used to name the corresponding sub-directory and will be set in the * configuration file. */ public static final String PARAM_CORPUS_NAME = "corpusName"; @ConfigurationParameter(name = PARAM_CORPUS_NAME, mandatory = true) protected String corpusName; /** * Directory in which to store output files. */ public static final String PARAM_TARGET_LOCATION = ComponentParameters.PARAM_TARGET_LOCATION; @ConfigurationParameter(name = PARAM_TARGET_LOCATION, mandatory = true) protected File targetLocation; /** * If set to true, the new corpus will be appended to an existing config file. If false, the * existing file is overwritten. Default: true. */ public static final String PARAM_APPEND_CONFIG = "appendConfig"; @ConfigurationParameter(name = PARAM_APPEND_CONFIG, mandatory = true, defaultValue = "true") protected boolean appendConfig; /** * If set, only documents with one of the listed collection IDs are written, all others are * ignored. If this is empty (null), all documents are written. */ public final static String PARAM_COLLECTION_VALUES = "collectionValues"; @ConfigurationParameter(name = PARAM_COLLECTION_VALUES, mandatory = false) protected String[] collectionValues; /** * If true (default), only write documents with collection ids matching one of the collection * values exactly. If false, write documents with collection ids containing any of the * collection value string in collection while ignoring cases. */ public final static String PARAM_COLLECTION_VALUES_EXACT_MATCH = "collectionValuesExactMatch"; @ConfigurationParameter(name = PARAM_COLLECTION_VALUES_EXACT_MATCH, mandatory = true, defaultValue = "true") protected boolean collectionValuesExactMatch; private ParallelTopicModel model; private File collectionDir; protected Set<String> collectionValuesSet; private Bag<String> collectionCounter; protected BufferedWriter writerDocTopic; @Override public void initialize(UimaContext context) throws ResourceInitializationException { super.initialize(context); try { model = ParallelTopicModel.read(modelLocation); collectionDir = new File(targetLocation, corpusName + "_" + model.getNumTopics()); if (collectionDir.exists()) { getLogger().warn( String.format("%s' already exists, overwriting content.", collectionDir)); } collectionDir.mkdirs(); initializeTopicFile(); } catch (Exception e) { throw new ResourceInitializationException(e); } collectionValuesSet = collectionValues == null ? Collections.<String> emptySet() : new HashSet<>(Arrays.asList(collectionValues)); collectionCounter = new HashBag<>(); } @Override public void process(JCas aJCas) throws AnalysisEngineProcessException { for (TopicDistribution distribution : select(aJCas, TopicDistribution.class)) { String docName = getDocumentId(aJCas); String collectionId = getCollectionId(aJCas); /* Print and gather collection statistics */ if (collectionCounter.getCount(collectionId) == 0) { getLogger().info("New collection ID observed: " + collectionId); } collectionCounter.add(collectionId); try { writeDocTopic(distribution, docName, collectionId); } catch (IOException e) { throw new AnalysisEngineProcessException(e); } } } protected void writeDocTopic(TopicDistribution distribution, String docName, String collectionId) throws IOException { /* filter by collection id if PARAM_COLLECTION_VALUES is set */ if (collectionValuesSet.isEmpty() || collectionValuesSet.contains(collectionId)) { /* write documents to file */ writerDocTopic.write(collectionId + ","); writerDocTopic.write(docName); DoubleArray proportions = distribution.getTopicProportions(); for (double topicProb : proportions.toArray()) { writerDocTopic.write("," + topicProb); } writerDocTopic.newLine(); } } @Override public void collectionProcessComplete() throws AnalysisEngineProcessException { super.collectionProcessComplete(); getLogger().info("Collection statistics: " + collectionCounter.toString()); getLogger().info( collectionValuesSet.isEmpty() ? "Writing all documents." : "Writing documents from these collections only: " + collectionValuesSet.toString()); try { writerDocTopic.close(); writetermMatrixFiles(); writeConfigFile(); } catch (IOException e) { throw new AnalysisEngineProcessException(e); } } private void initializeTopicFile() throws IOException { File topicFile = new File(collectionDir, DOC_TOPICS_FILE); getLogger().info(String.format("Writing file '%s'.", topicFile.getPath())); writerDocTopic = new BufferedWriter(new FileWriter(topicFile)); /* Write header */ writerDocTopic.write("Class,Document"); for (int j = 0; j < model.numTopics; j++) { writerDocTopic.write(",T" + j); } writerDocTopic.newLine(); } /** * This method has been copied and slightly adapted from MalletLDA#printTopics in the original * DiTop code. * * @throws IOException * if a low-level I/O error occurs */ private void writetermMatrixFiles() throws IOException { File topicTermFile = new File(collectionDir, TOPIC_TERM_FILE); File topicTermMatrixFile = new File(collectionDir, TOPIC_TERM_MATRIX_FILE); File topicSummaryFile = new File(collectionDir, TOPIC_SUMMARY_FILE); BufferedWriter writerTopicTerm = new BufferedWriter(new FileWriter(topicTermFile)); BufferedWriter writerTopicTermMatrix = new BufferedWriter(new FileWriter( topicTermMatrixFile)); BufferedWriter writerTopicTermShort = new BufferedWriter(new FileWriter(topicSummaryFile)); getLogger().info(String.format("Writing file '%s'.", topicTermFile)); getLogger().info(String.format("Writing file '%s'.", topicTermMatrixFile)); getLogger().info(String.format("Writing file '%s'.", topicSummaryFile)); /* Write topic term associations */ Alphabet alphabet = model.getAlphabet(); for (int i = 0; i < model.getSortedWords().size(); i++) { writerTopicTerm.write("TOPIC " + i + ": "); writerTopicTermShort.write("TOPIC " + i + ": "); writerTopicTermMatrix.write("TOPIC " + i + ": "); /** topic for the label */ int count = 0; TreeSet<IDSorter> set = model.getSortedWords().get(i); for (IDSorter s : set) { if (count <= maxTopicWords) { writerTopicTermShort.write(alphabet.lookupObject(s.getID()) + ", "); } count++; writerTopicTerm.write(alphabet.lookupObject(s.getID()) + ", "); writerTopicTermMatrix.write(alphabet.lookupObject(s.getID()) + " (" + s.getWeight() + "), "); /** add to topic label */ } writerTopicTerm.newLine(); writerTopicTermShort.newLine(); writerTopicTermMatrix.newLine(); } writerTopicTermMatrix.close(); writerTopicTerm.close(); writerTopicTermShort.close(); } private void writeConfigFile() throws IOException { File configFile = new File(targetLocation, CONFIG_FILE); Map<String, Set<Integer>> corpora; // holds all corpus names mapped to (multiple) topic // numbers Set<Integer> currentCorpusTopicNumbers; // entry for the current, new topic if (appendConfig && configFile.exists()) { // read existing entries from config file corpora = readConfigFile(configFile); currentCorpusTopicNumbers = corpora.containsKey(corpusName) ? corpora.get(corpusName) : new HashSet<>(); } else { corpora = new HashMap<>(); currentCorpusTopicNumbers = new HashSet<>(1, 1); } currentCorpusTopicNumbers.add(model.getNumTopics()); corpora.put(corpusName, currentCorpusTopicNumbers); getLogger().info(String.format("Writing configuration file '%s'.", configFile.getPath())); BufferedWriter configWriter = new BufferedWriter(new FileWriter(configFile)); for (Entry<String, Set<Integer>> entry : corpora.entrySet()) { configWriter.write(entry.getKey()); for (Integer topicNumber : entry.getValue()) { configWriter.write(FIELDSEPARATOR_CONFIGFILE + topicNumber); } configWriter.newLine(); } configWriter.close(); } /** * Read config file in the form <corpus>;<ntopics>[;<ntopics>...] * <p> * Results in a Map <corpusname>:Set(ntopics1, ...) * * @param configFile * the config file to read * @return a map containing corpus names as keys and a set of topic numbers as values * @throws IOException * if an I/O error occurs. */ private static Map<String, Set<Integer>> readConfigFile(File configFile) throws IOException { Map<String, Set<Integer>> entries = new HashMap<>(); for (String line : FileUtils.readLines(configFile)) { String[] fields = line.split(FIELDSEPARATOR_CONFIGFILE); if (fields.length < 2) { throw new IllegalStateException(String.format( "Could not parse config file '%s': Invalid line:%n%s", configFile, line)); } if (entries.containsKey(fields[0])) { throw new IllegalStateException(String.format( "Could not parse config file '%s': duplicate corpus entry '%s'.", configFile, fields[0])); } Set<Integer> topicCounts = new HashSet<>(fields.length - 1); for (int i = 1; i < fields.length; i++) { try { topicCounts.add(Integer.parseInt(fields[i])); } catch (NumberFormatException e) { throw new IllegalStateException(String.format( "Could not parse config file '%s': Invalid topic number '%s'.", configFile, fields[i])); } } entries.put(fields[0], topicCounts); } return entries; } /** * Extract the collection id from the JCas. Uses {@link DocumentMetaData#getCollectionId()}, but * this method can be overwritten to select a different source for the collection id. * * @param aJCas * the JCas. * @return the collection id String or null if it is not available. */ protected String getCollectionId(JCas aJCas) { String collectionId = DocumentMetaData.get(aJCas).getCollectionId(); if (collectionId == null) { throw new IllegalStateException("Could not extract collection ID for document"); } if (!collectionValuesExactMatch && !collectionValuesSet.contains(collectionId)) { collectionId = expandCollectionId(collectionId); } return collectionId; } /** * This method checks whether any of the specified collection values contains the given String. * If it does, returns the matching value; if not, it returns the original value. * * @param collectionId * the collection ID. * @return the first entry from {@code collectionValuesSet} that contains the (lowercased) * {@code collectionId} or the input {@code collectionId}. */ protected String expandCollectionId(String collectionId) { assert !collectionValuesExactMatch; for (String value : collectionValuesSet) { if (collectionId.toLowerCase().contains(value.toLowerCase())) { getLogger().debug( String.format("Changing collection ID from '%s' to '%s'.", collectionId, value)); return value; } } return collectionId; } /** * Extract the document id from the JCas. Uses {@link DocumentMetaData#getDocumentId()}, but * this method can be overwritten to select a different source for the document id. * * @param aJCas * the JCas. * @return the document id string or null if it is not available. */ protected String getDocumentId(JCas aJCas) throws IllegalStateException { String docName = DocumentMetaData.get(aJCas).getDocumentId(); if (docName == null) { throw new IllegalStateException("Could not extract document ID from metadata."); } return docName; } }