/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.utils.vectors.lucene; import java.io.File; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.Writer; import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import com.google.common.base.Charsets; import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; import org.apache.commons.cli2.Option; import org.apache.commons.cli2.OptionException; import org.apache.commons.cli2.builder.ArgumentBuilder; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.commons.cli2.builder.GroupBuilder; import org.apache.commons.cli2.commandline.Parser; import org.apache.hadoop.fs.Path; import org.apache.lucene.document.FieldSelector; import org.apache.lucene.document.SetBasedFieldSelector; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermDocs; import org.apache.lucene.index.TermEnum; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.util.OpenBitSet; import org.apache.mahout.clustering.WeightedVectorWritable; import org.apache.mahout.common.CommandLineUtil; import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.math.NamedVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.stats.LogLikelihood; import org.apache.mahout.utils.clustering.ClusterDumper; import org.apache.mahout.utils.vectors.TermEntry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Get labels for the cluster using Log Likelihood Ratio (LLR). * <p/> *"The most useful way to think of this (LLR) is as the percentage of in-cluster documents that have the * feature (term) versus the percentage out, keeping in mind that both percentages are uncertain since we have * only a sample of all possible documents." - Ted Dunning * <p/> * More about LLR can be found at : http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html */ public class ClusterLabels { private static final Logger log = LoggerFactory.getLogger(ClusterLabels.class); public static final int DEFAULT_MIN_IDS = 50; public static final int DEFAULT_MAX_LABELS = 25; private final String indexDir; private final String contentField; private String idField; private final Map<Integer, List<WeightedVectorWritable>> clusterIdToPoints; private String output; private int minNumIds; private int maxLabels; public ClusterLabels(Path seqFileDir, Path pointsDir, String indexDir, String contentField, int minNumIds, int maxLabels) { this.indexDir = indexDir; this.contentField = contentField; this.minNumIds = minNumIds; this.maxLabels = maxLabels; this.minNumIds = DEFAULT_MIN_IDS; this.maxLabels = DEFAULT_MAX_LABELS; ClusterDumper clusterDumper = new ClusterDumper(seqFileDir, pointsDir); this.clusterIdToPoints = clusterDumper.getClusterIdToPoints(); } public void getLabels() throws IOException { Writer writer; if (this.output == null) { writer = new OutputStreamWriter(System.out); } else { writer = Files.newWriter(new File(this.output), Charsets.UTF_8); } try { for (Map.Entry<Integer, List<WeightedVectorWritable>> integerListEntry : clusterIdToPoints.entrySet()) { List<WeightedVectorWritable> wvws = integerListEntry.getValue(); List<TermInfoClusterInOut> termInfos = getClusterLabels(integerListEntry.getKey(), wvws); if (termInfos != null) { writer.write('\n'); writer.write("Top labels for Cluster "); writer.write(String.valueOf(integerListEntry.getKey())); writer.write(" containing "); writer.write(String.valueOf(wvws.size())); writer.write(" vectors"); writer.write('\n'); writer.write("Term \t\t LLR \t\t In-ClusterDF \t\t Out-ClusterDF "); writer.write('\n'); for (TermInfoClusterInOut termInfo : termInfos) { writer.write(termInfo.getTerm()); writer.write("\t\t"); writer.write(String.valueOf(termInfo.getLogLikelihoodRatio())); writer.write("\t\t"); writer.write(String.valueOf(termInfo.getInClusterDF())); writer.write("\t\t"); writer.write(String.valueOf(termInfo.getOutClusterDF())); writer.write('\n'); } } } } finally { Closeables.closeQuietly(writer); } } /** * Get the list of labels, sorted by best score. */ protected List<TermInfoClusterInOut> getClusterLabels(Integer integer, Collection<WeightedVectorWritable> wvws) throws IOException { if (wvws.size() < minNumIds) { log.info("Skipping small cluster {} with size: {}", integer, wvws.size()); return null; } log.info("Processing Cluster {} with {} documents", integer, wvws.size()); Directory dir = FSDirectory.open(new File(this.indexDir)); IndexReader reader = IndexReader.open(dir, false); log.info("# of documents in the index {}", reader.numDocs()); Collection<String> idSet = new HashSet<String>(); for (WeightedVectorWritable wvw : wvws) { Vector vector = wvw.getVector(); if (vector instanceof NamedVector) { idSet.add(((NamedVector) vector).getName()); } } int numDocs = reader.numDocs(); OpenBitSet clusterDocBitset = getClusterDocBitset(reader, idSet, this.idField); log.info("Populating term infos from the index"); /** * This code is as that of CachedTermInfo, with one major change, which is to get the document frequency. * * Since we have deleted the documents out of the cluster, the document frequency for a term should only * include the in-cluster documents. The document frequency obtained from TermEnum reflects the frequency * in the entire index. To get the in-cluster frequency, we need to query the index to get the term * frequencies in each document. The number of results of this call will be the in-cluster document * frequency. */ TermEnum te = reader.terms(new Term(contentField, "")); Map<String, TermEntry> termEntryMap = new LinkedHashMap<String, TermEntry>(); try { int count = 0; do { Term term = te.term(); if (term == null || !term.field().equals(contentField)) { break; } OpenBitSet termBitset = new OpenBitSet(reader.maxDoc()); // Generate bitset for the term TermDocs termDocs = reader.termDocs(term); while (termDocs.next()) { termBitset.set(termDocs.doc()); } // AND the term's bitset with cluster doc bitset to get the term's in-cluster frequency. // This modifies the termBitset, but that's fine as we are not using it anywhere else. termBitset.and(clusterDocBitset); int inclusterDF = (int) termBitset.cardinality(); TermEntry entry = new TermEntry(term.text(), count++, inclusterDF); termEntryMap.put(entry.getTerm(), entry); } while (te.next()); } finally { Closeables.closeQuietly(te); } List<TermInfoClusterInOut> clusteredTermInfo = new LinkedList<TermInfoClusterInOut>(); int clusterSize = wvws.size(); for (TermEntry termEntry : termEntryMap.values()) { int corpusDF = reader.terms(new Term(this.contentField, termEntry.getTerm())).docFreq(); int outDF = corpusDF - termEntry.getDocFreq(); int inDF = termEntry.getDocFreq(); double logLikelihoodRatio = scoreDocumentFrequencies(inDF, outDF, clusterSize, numDocs); TermInfoClusterInOut termInfoCluster = new TermInfoClusterInOut(termEntry.getTerm(), inDF, outDF, logLikelihoodRatio); clusteredTermInfo.add(termInfoCluster); } Collections.sort(clusteredTermInfo); // Cleanup Closeables.closeQuietly(reader); termEntryMap.clear(); return clusteredTermInfo.subList(0, Math.min(clusteredTermInfo.size(), maxLabels)); } private static OpenBitSet getClusterDocBitset(IndexReader reader, Collection<String> idSet, String idField) throws IOException { int numDocs = reader.numDocs(); OpenBitSet bitset = new OpenBitSet(numDocs); FieldSelector idFieldSelector = new SetBasedFieldSelector(Collections.singleton(idField), Collections.<String>emptySet()); for (int i = 0; i < numDocs; i++) { String id; // Use Lucene's internal ID if idField is not specified. Else, get it from the document. if (idField == null) { id = Integer.toString(i); } else { id = reader.document(i, idFieldSelector).get(idField); } if (idSet.contains(id)) { bitset.set(i); } } log.info("Created bitset for in-cluster documents : {}", bitset.cardinality()); return bitset; } private static double scoreDocumentFrequencies(long inDF, long outDF, long clusterSize, long corpusSize) { long k12 = clusterSize - inDF; long k22 = corpusSize - clusterSize - outDF; return LogLikelihood.logLikelihoodRatio(inDF, k12, outDF, k22); } public String getIdField() { return idField; } public void setIdField(String idField) { this.idField = idField; } public String getOutput() { return output; } public void setOutput(String output) { this.output = output; } public static void main(String[] args) { DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); ArgumentBuilder abuilder = new ArgumentBuilder(); GroupBuilder gbuilder = new GroupBuilder(); Option indexOpt = obuilder.withLongName("dir").withRequired(true).withArgument( abuilder.withName("dir").withMinimum(1).withMaximum(1).create()) .withDescription("The Lucene index directory").withShortName("d").create(); Option outputOpt = obuilder.withLongName("output").withRequired(false).withArgument( abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription( "The output file. If not specified, the result is printed on console.").withShortName("o").create(); Option fieldOpt = obuilder.withLongName("field").withRequired(true).withArgument( abuilder.withName("field").withMinimum(1).withMaximum(1).create()) .withDescription("The content field in the index").withShortName("f").create(); Option idFieldOpt = obuilder.withLongName("idField").withRequired(false).withArgument( abuilder.withName("idField").withMinimum(1).withMaximum(1).create()).withDescription( "The field for the document ID in the index. If null, then the Lucene internal doc " + "id is used which is prone to error if the underlying index changes").withShortName("i").create(); Option seqOpt = obuilder.withLongName("seqFileDir").withRequired(true).withArgument( abuilder.withName("seqFileDir").withMinimum(1).withMaximum(1).create()).withDescription( "The directory containing Sequence Files for the Clusters").withShortName("s").create(); Option pointsOpt = obuilder.withLongName("pointsDir").withRequired(true).withArgument( abuilder.withName("pointsDir").withMinimum(1).withMaximum(1).create()).withDescription( "The directory containing points sequence files mapping input vectors to their cluster. ") .withShortName("p").create(); Option minClusterSizeOpt = obuilder.withLongName("minClusterSize").withRequired(false).withArgument( abuilder.withName("minClusterSize").withMinimum(1).withMaximum(1).create()).withDescription( "The minimum number of points required in a cluster to print the labels for").withShortName("m").create(); Option maxLabelsOpt = obuilder.withLongName("maxLabels").withRequired(false).withArgument( abuilder.withName("maxLabels").withMinimum(1).withMaximum(1).create()).withDescription( "The maximum number of labels to print per cluster").withShortName("x").create(); Option helpOpt = DefaultOptionCreator.helpOption(); Group group = gbuilder.withName("Options").withOption(indexOpt).withOption(idFieldOpt).withOption(outputOpt) .withOption(fieldOpt).withOption(seqOpt).withOption(pointsOpt).withOption(helpOpt) .withOption(maxLabelsOpt).withOption(minClusterSizeOpt).create(); try { Parser parser = new Parser(); parser.setGroup(group); CommandLine cmdLine = parser.parse(args); if (cmdLine.hasOption(helpOpt)) { CommandLineUtil.printHelp(group); return; } Path seqFileDir = new Path(cmdLine.getValue(seqOpt).toString()); Path pointsDir = new Path(cmdLine.getValue(pointsOpt).toString()); String indexDir = cmdLine.getValue(indexOpt).toString(); String contentField = cmdLine.getValue(fieldOpt).toString(); String idField = null; if (cmdLine.hasOption(idFieldOpt)) { idField = cmdLine.getValue(idFieldOpt).toString(); } String output = null; if (cmdLine.hasOption(outputOpt)) { output = cmdLine.getValue(outputOpt).toString(); } int maxLabels = DEFAULT_MAX_LABELS; if (cmdLine.hasOption(maxLabelsOpt)) { maxLabels = Integer.parseInt(cmdLine.getValue(maxLabelsOpt).toString()); } int minSize = DEFAULT_MIN_IDS; if (cmdLine.hasOption(minClusterSizeOpt)) { minSize = Integer.parseInt(cmdLine.getValue(minClusterSizeOpt).toString()); } ClusterLabels clusterLabel = new ClusterLabels(seqFileDir, pointsDir, indexDir, contentField, minSize, maxLabels); if (idField != null) { clusterLabel.setIdField(idField); } if (output != null) { clusterLabel.setOutput(output); } clusterLabel.getLabels(); } catch (OptionException e) { log.error("Exception", e); CommandLineUtil.printHelp(group); } catch (IOException e) { log.error("Exception", e); } } }