package com.scaleunlimited.cascading.ml;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cascading.flow.FlowProcess;
import cascading.operation.BaseOperation;
import cascading.operation.Buffer;
import cascading.operation.BufferCall;
import cascading.operation.Function;
import cascading.operation.FunctionCall;
import cascading.operation.Identity;
import cascading.pipe.CoGroup;
import cascading.pipe.Each;
import cascading.pipe.Every;
import cascading.pipe.GroupBy;
import cascading.pipe.Pipe;
import cascading.pipe.SubAssembly;
import cascading.pipe.assembly.AggregateBy;
import cascading.pipe.assembly.SumBy;
import cascading.pipe.joiner.LeftJoin;
import cascading.tuple.Fields;
import cascading.tuple.Tuple;
import cascading.tuple.TupleEntry;

import com.scaleunlimited.cascading.NullContext;

/**
 * Sub-assembly that finds the highest scoring terms of each document.
 *
 * <p>The text of every incoming tuple is split into terms by an {@link ITermsParser}. Each
 * term is then scored with {@code LogLikelihood.rootLogLikelihoodRatio} from its count in the
 * document versus its count over all documents, and at most
 * {@link ITermsFilter#getMaxResults()} terms per document are emitted in the {@code terms}
 * and {@code scores} fields.
 */
@SuppressWarnings("serial")
public class TopTermsByLLR extends SubAssembly {
    private static final Logger LOGGER = LoggerFactory.getLogger(TopTermsByLLR.class);

    /**
     * Function that turns a text argument into one {@code (term, term_count)} tuple per
     * distinct term, followed by a {@code (null, totalTerms)} tuple holding the number of
     * terms the parser returned for that text.
     */
    private static class ExtractTerms extends BaseOperation<NullContext> implements Function<NullContext> {

        private ITermsParser _parser;

        public ExtractTerms(ITermsParser parser) {
            super(new Fields("term", "term_count"));
            _parser = parser;
        }

        @SuppressWarnings("rawtypes")
        @Override
        public void operate(FlowProcess flowProcess, FunctionCall<NullContext> functionCall) {
            _parser.reset(functionCall.getArguments().getString(0));

            // TODO use fastutils code, e.g. Object2IntOpenHashMap. Or get long hash from
            // string, but then we'd need to re-join top terms (by hash) against the actual
            // term at the end of the workflow
            Map<String, Integer> terms = new HashMap<String, Integer>();
            int totalTerms = 0;
            for (String term : _parser) {
                totalTerms += 1;
                Integer termCount = terms.get(term);
                if (termCount == null) {
                    terms.put(term, 1);
                } else {
                    terms.put(term, termCount + 1);
                }
            }

            // Walk the entries directly so we don't pay for a second lookup per term.
            for (Map.Entry<String, Integer> entry : terms.entrySet()) {
                functionCall.getOutputCollector().add(new Tuple(entry.getKey(), entry.getValue()));
            }

            // The null term carries the total number of terms seen in this text.
            functionCall.getOutputCollector().add(new Tuple(null, totalTerms));
        }
    }

    /**
     * A term and its score. The natural ordering is by descending score, so sorting a list
     * of these puts the highest score first and the lowest score last.
     */
    private static class TermAndScore implements Comparable<TermAndScore> {
        final String _term;
        final double _score;

        public TermAndScore(String term, double score) {
            _term = term;
            _score = score;
        }

        @Override
        public int compareTo(TermAndScore o) {
            if (_score > o._score) {
                return -1;
            } else if (_score < o._score) {
                return 1;
            } else {
                return 0;
            }
        }
    }

    /**
     * Mutable state shared between successive {@code countTerms} calls: the counts for the
     * term that was just completed, plus the first tuple of the following term, which has
     * to be read in order to notice that the term changed.
     */
    private static class TermAndCounts {
        public boolean atEnd;
        public String curTerm;
        public int docTermCount;
        public int totalTermCount;

        // Save information from next term we find.
        public String nextTerm;
        public int nextDocTermCount;
        public int nextTotalTermCount;

        @Override
        public String toString() {
            return String.format("\"%s\"=%d/%d, next \"%s\"=%d", curTerm, docTermCount, totalTermCount, nextTerm, nextDocTermCount);
        }
    }

    /**
     * Buffer that receives all {@code (term, term_count, total_count)} tuples of one group
     * and emits a single tuple whose {@code terms} and {@code scores} fields are parallel
     * tuples holding the best scoring terms, highest score first.
     */
    private static class CalcLLR extends BaseOperation<NullContext> implements Buffer<NullContext> {

        private ITermsFilter _filter;
        private ITermsParser _parser;

        public CalcLLR(ITermsParser parser, ITermsFilter filter) {
            super(new Fields("terms", "scores"));
            _parser = parser;
            _filter = filter;
        }

        /**
         * Scores every term of the current group and emits the top terms.
         *
         * @throws RuntimeException if the group has no tuples at all.
         */
        @SuppressWarnings("rawtypes")
        @Override
        public void operate(FlowProcess flowProcess, BufferCall<NullContext> bufferCall) {
            TupleEntry docid = bufferCall.getGroup();
            Iterator<TupleEntry> iter = bufferCall.getArgumentsIterator();
            if (!iter.hasNext()) {
                throw new RuntimeException(String.format("Impossible situation - group for docid %s has no members", docid));
            }

            // The first run of tuples is expected to be the null term, which carries the
            // document's term total (term_count) and the global term total (total_count).
            TermAndCounts termCounts = new TermAndCounts();
            countTerms(iter, termCounts);
            // NOTE(review): curTerm is always null after the first countTerms() call (it is
            // copied from the still-unset nextTerm), so this check cannot currently fire.
            if (termCounts.curTerm != null) {
                throw new RuntimeException(String.format("Impossible situation - first term for docid %s isn't null", docid));
            }

            int globalTermCount = termCounts.totalTermCount;
            int docTermCount = termCounts.docTermCount;

            // We can get multiple tuples for the same term; countTerms() adds up their
            // term_count values before returning.

            // Now we can start iterating over the terms for this document, calculating their LLR score and keeping
            // the top N
            int maxResults = _filter.getMaxResults();
            // A non-positive limit means "keep nothing", so don't let it reach ArrayList.
            List<TermAndScore> queue = new ArrayList<TermAndScore>(Math.max(0, maxResults));

            while (countTerms(iter, termCounts)) {
                // LOGGER.info(termCounts);

                int termCount = termCounts.docTermCount;
                int termTotalCount = termCounts.totalTermCount;

                // k11 is the count of this term in this document
                long k11 = termCount;

                // k12 is the count of all other terms in this document
                long k12 = docTermCount - termCount;

                // k21 is the count of this term in all other documents.
                long k21 = termTotalCount - termCount;

                // k22 is the count of all other terms in all other documents
                long k22 = globalTermCount - docTermCount - k21;

                double score;
                try {
                    score = LogLikelihood.rootLogLikelihoodRatio(k11, k12, k21, k22);
                } catch (IllegalArgumentException e) {
                    // Skip just this term; the remaining terms of the document are still scored.
                    LOGGER.warn(String.format("Invalid LLR values for %s in %s: k11=%d, k12=%d, k21=%d, k22=%d",
                                    termCounts.curTerm, docid.getTuple(), k11, k12, k21, k22));
                    continue;
                }

                // See if any filtering is needed.
                if (_filter.filter(score, termCounts.curTerm, _parser)) {
                    continue;
                }

                offerTerm(queue, maxResults, termCounts.curTerm, score);
            }

            // The list is only sorted once it fills up, so a document with fewer than
            // maxResults surviving terms still needs ordering here. Without this the output
            // order would differ between "full" and "not full" documents.
            if (queue.size() < maxResults) {
                Collections.sort(queue);
            }

            // At the end we'll have the top terms & scores
            Tuple terms = new Tuple();
            Tuple scores = new Tuple();
            for (TermAndScore tas : queue) {
                terms.add(tas._term);
                scores.add(tas._score);
            }

            bufferCall.getOutputCollector().add(new Tuple(terms, scores));
        }

        /**
         * Adds a scored term to the top-N list if it belongs there.
         *
         * <p>While the list is not full the term is simply appended; the list is sorted at
         * the moment it becomes full. Once full, the list is kept sorted (lowest score last)
         * and a new term replaces the last entry only when its score is strictly higher.
         *
         * @param topTerms list being built up; modified in place
         * @param maxResults maximum number of entries to keep; nothing is kept when this is
         *        zero or negative
         * @param term the term to add
         * @param score the term's score
         */
        private static void offerTerm(List<TermAndScore> topTerms, int maxResults, String term, double score) {
            if (maxResults <= 0) {
                // Previously this case indexed position -1 and blew up.
                return;
            }

            if (topTerms.size() < maxResults) {
                topTerms.add(new TermAndScore(term, score));
                if (topTerms.size() == maxResults) {
                    // Set up for next call, where last score must be lowest
                    Collections.sort(topTerms);
                }
            } else if (topTerms.get(maxResults - 1)._score < score) {
                topTerms.set(maxResults - 1, new TermAndScore(term, score));
                Collections.sort(topTerms);
            }
        }

        /**
         * Advances to the next term of the group and accumulates its counts.
         *
         * <p>On return {@code termCounts.curTerm}, {@code docTermCount} and
         * {@code totalTermCount} describe the completed term, and the {@code next*} fields
         * hold the first tuple of the following term (if there is one).
         *
         * @param iter iterator over the group's tuples, shared across calls
         * @param termCounts state carried from the previous call; updated in place
         * @return {@code false} if a previous call already consumed the last tuple,
         *         otherwise {@code true}
         */
        private boolean countTerms(Iterator<TupleEntry> iter, TermAndCounts termCounts) {
            if (termCounts.atEnd) {
                return false;
            }

            // While the iterator is returning terms equal to what's in termCounts, keep
            // counting the number of terms. Once the term changes, or we have no more
            // terms, return results.
            termCounts.curTerm = termCounts.nextTerm;
            termCounts.docTermCount = termCounts.nextDocTermCount;
            termCounts.totalTermCount = termCounts.nextTotalTermCount;

            termCounts.nextTerm = null;
            termCounts.nextDocTermCount = 0;
            termCounts.nextTotalTermCount = 0;

            String curTerm = termCounts.curTerm;

            while (iter.hasNext()) {
                TupleEntry te = iter.next();
                String newTerm = te.getString("term");
                int newTermCount = te.getInteger("term_count");

                if ((curTerm == null) && (newTerm == null)) {
                    // We have a match. Special case for when we're processing the special
                    // null term, since we won't have a previous item.
                    if (termCounts.totalTermCount == 0) {
                        termCounts.totalTermCount = te.getInteger("total_count");
                    }
                } else if ((curTerm == null) || !newTerm.equals(curTerm)) {
                    // Switching terms, return what we've got.
                    termCounts.nextTerm = newTerm;
                    termCounts.nextDocTermCount = newTermCount;
                    termCounts.nextTotalTermCount = te.getInteger("total_count");
                    return true;
                }

                termCounts.docTermCount += newTermCount;
            }

            // We ran out of terms, so we're done.
            termCounts.atEnd = true;
            return true;
        }
    }

    /**
     * Convenience constructor that reads the document id from the {@code docId} field and
     * the text from the {@code text} field, never filters out a term, and uses
     * {@code AggregateBy.CompositeFunction.DEFAULT_THRESHOLD} for pre-aggregation.
     *
     * @param docsPipe pipe providing the documents
     * @param parser parser used to split the text into terms
     * @param maxTerms maximum number of terms to emit per document
     */
    public TopTermsByLLR(Pipe docsPipe, ITermsParser parser, final int maxTerms) {
        this(docsPipe, parser, new ITermsFilter() {

            @Override
            public int getMaxResults() {
                return maxTerms;
            }

            @Override
            public boolean filter(double llrScore, String term, ITermsParser parser) {
                return false;
            }
        }, new Fields("docId"), new Fields("text"), AggregateBy.CompositeFunction.DEFAULT_THRESHOLD);
    }

    /**
     * Given a pipe containing tuples with {@code docIdFields} and a {@code textField}, parse
     * the text using {@code parser}, score the resulting terms using LLR, and only emit
     * terms for which {@code filter} does not return {@code true}, limited to
     * {@code filter.getMaxResults()} terms per document.
     *
     * @param docsPipe pipe providing the documents
     * @param parser parser used to split the text into terms
     * @param filter decides which scored terms are dropped, and how many terms are kept
     * @param docIdFields field(s) that identify a document
     * @param textField field holding the text to parse
     * @param threshold - size of LRU cache for map-side pre-aggregation.
     */
    public TopTermsByLLR(Pipe docsPipe, ITermsParser parser, ITermsFilter filter, Fields docIdFields, Fields textField, int threshold) {
        super(docsPipe);

        // We assume each document has one or more fields that identify each "document", and a text field
        Pipe termsPipe = new Pipe("terms", docsPipe);
        termsPipe = new Each(termsPipe, textField, new ExtractTerms(parser), Fields.SWAP);

        // We've got docid, term, term count. Generate term, total count. This will
        // include the null term which will be the total count of all terms.
        Pipe termCountPipe = new Pipe("term count", termsPipe);
        termCountPipe = new SumBy(termCountPipe, new Fields("term"), new Fields("term_count"), new Fields("total_count"), Integer.class, threshold);
        // termCountPipe = new Each(termCountPipe, new Debug("summed", true));

        // Join termCountsPipe with our termsPipe by term, so we get
        // docid, term, doc term count, total term count
        // This will include docid, null, total terms in doc, global terms count
        Pipe allTermData = new CoGroup(
                        termsPipe, new Fields("term"),
                        termCountPipe, new Fields("term"),
                        docIdFields.append(new Fields("term", "term_count", "term_ignore", "total_count")),
                        new LeftJoin());

        // Drop the duplicate join key ("term_ignore") that came from the right-hand side.
        Fields termFields = new Fields("term", "term_count", "total_count");
        allTermData = new Each(allTermData, docIdFields.append(termFields), new Identity());
        // allTermData = new Each(allTermData, new Debug("grouped", true));

        // Group by document, with each group's tuples sorted by term for CalcLLR.
        allTermData = new GroupBy(allTermData, docIdFields, new Fields("term"));
        allTermData = new Every(allTermData, termFields, new CalcLLR(parser, filter), Fields.SWAP);

        setTails(allTermData);
    }
}