package com.scaleunlimited.cascading.ml;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import cascading.flow.FlowProcess;
import cascading.operation.BaseOperation;
import cascading.operation.Buffer;
import cascading.operation.BufferCall;
import cascading.operation.Function;
import cascading.operation.FunctionCall;
import cascading.operation.Identity;
import cascading.operation.OperationCall;
import cascading.operation.expression.ExpressionFilter;
import cascading.pipe.CoGroup;
import cascading.pipe.Each;
import cascading.pipe.Every;
import cascading.pipe.GroupBy;
import cascading.pipe.HashJoin;
import cascading.pipe.Pipe;
import cascading.pipe.SubAssembly;
import cascading.pipe.assembly.SumBy;
import cascading.pipe.joiner.LeftJoin;
import cascading.pipe.joiner.OuterJoin;
import cascading.tuple.Fields;
import cascading.tuple.Tuple;
import cascading.tuple.TupleEntry;
import com.scaleunlimited.cascading.NullContext;
import com.scaleunlimited.cascading.UniqueCount;
@SuppressWarnings("serial")
public class TopTermsByTfIdf extends SubAssembly {
private static class ExtractTerms extends BaseOperation<NullContext> implements Function<NullContext> {
private ITermsParser _parser;
private transient Tuple _result;
private transient Tuple _emptyTerm;
public ExtractTerms(ITermsParser parser) {
super(new Fields("term", "tf", "joiner"));
_parser = parser;
}
@Override
public void prepare(FlowProcess flowProcess, OperationCall<NullContext> operationCall) {
super.prepare(flowProcess, operationCall);
_result = new Tuple("", 0.0f, "");
_emptyTerm = new Tuple("", 0.0f, "x");
}
@Override
public void operate(FlowProcess flowProcess, FunctionCall<NullContext> functionCall) {
_parser.reset(functionCall.getArguments().getString("text"));
// TODO use fastutils code, e.g. Object2IntOpenHashMap. Or get long hash from
// string, but then we'd need to re-join top terms (by hash) against the actual
// term at the end of the workflow
Map<String, Integer> terms = new HashMap<String, Integer>();
int totalTerms = 0;
for (String term : _parser) {
if ((term == null) || term.isEmpty()) {
// Ignore empty terms, as that messes with our logic below.
continue;
}
totalTerms += 1;
Integer termCount = terms.get(term);
if (termCount == null) {
terms.put(term, 1);
} else {
terms.put(term, termCount + 1);
}
}
for (String term : terms.keySet()) {
_result.setString(0, term);
_result.setFloat(1, (float)terms.get(term)/(float)totalTerms);
functionCall.getOutputCollector().add(_result);
}
// And output special empty term, that we use to get a count of total documents.
functionCall.getOutputCollector().add(_emptyTerm);
}
}
private static class TermAndScore implements Comparable<TermAndScore> {
String _term;
double _score;
public TermAndScore(String term, double score) {
_term = term;
_score = score;
}
@Override
public int compareTo(TermAndScore o) {
if (_score > o._score) {
return -1;
} else if (_score < o._score) {
return 1;
} else {
return 0;
}
}
}
private static class CalcLLR extends BaseOperation<NullContext> implements Buffer<NullContext> {
private int _numTerms;
public CalcLLR(int numTerms) {
super(new Fields("docid", "terms", "scores"));
_numTerms = numTerms;
}
@Override
public void operate(FlowProcess flowProcess, BufferCall<NullContext> bufferCall) {
String docid = bufferCall.getGroup().getString("docid");
Iterator<TupleEntry> iter = bufferCall.getArgumentsIterator();
if (!iter.hasNext()) {
throw new RuntimeException(String.format("Impossible situation - group for docid %s has no members", docid));
}
TupleEntry te = iter.next();
String term = te.getString("term");
if (!term.isEmpty()) {
throw new RuntimeException(String.format("Impossible situation - first term for docid %s isn't empty", docid));
}
int docTermCount = te.getInteger("term_count");
int globalTermCount = te.getInteger("total_count");
// Now we can start iterating over the terms for this document, calculating their LLR score and keeping
// the top N
List<TermAndScore> queue = new ArrayList<TermAndScore>(_numTerms);
while (iter.hasNext()) {
te = iter.next();
int termCount = te.getInteger("term_count");
int termTotalCount = te.getInteger("total_terms");
// k11 is the count of this term in this document
long k11 = termCount;
// k12 is the count of all other terms in this document
long k12 = docTermCount - termCount;
// k21 is the count of this term in all other documents.
long k21 = termTotalCount - termCount;
// k22 is the count of all other terms in all other documents
// TODO KKr - should this then be subtracting k21 here, not docTermCount?
// And also k12?
long k22 = globalTermCount - docTermCount;
double score = LogLikelihood.logLikelihoodRatio(k11, k12, k21, k22);
if (queue.size() < _numTerms) {
queue.add(new TermAndScore(te.getString("term"), score));
Collections.sort(queue);
} else if (queue.get(_numTerms - 1)._score < score) {
queue.add(new TermAndScore(te.getString("term"), score));
Collections.sort(queue);
}
}
// At the end we'll have the top terms & scores
Tuple terms = new Tuple();
Tuple scores = new Tuple();
for (TermAndScore tas : queue) {
terms.add(tas._term);
scores.add(tas._score);
}
bufferCall.getOutputCollector().add(new Tuple(docid, terms, scores));
}
}
// TODO also take in IScorer scorer, which has methods to calculate TF score
// from term count and document count, and IDF score from doc count & total docs.
// TODO also take in IFilterTerms filter, which has methods to filter out
// terms based on term info (term itself, term count in doc, total terms in
// doc, TF score) and doc info (term itself, doc count, total docs, IDF score)
// TODO take in Fields param wich has field(s) for use as docid.
// TODO take in Fields param which has field for text.
public TopTermsByTfIdf(Pipe docsPipe, ITermsParser parser, int numTerms) {
super(docsPipe);
// We assume each document has a docid field, and a text field
Pipe termsPipe = new Pipe("terms", docsPipe);
termsPipe = new Each(termsPipe, new Fields("text"), new ExtractTerms(parser), Fields.REPLACE);
// We've got (docid, term, tf, "") for regular tuples, and (docid, "", 0.0f, "x") for
// special tuples used to count the total number of documents.
// We need term, IDF score. To get that, we need to calculate doc count for each term, and total doc count,
// and do the division.
Pipe docCountPipe = new Pipe("doc count", termsPipe);
docCountPipe = new UniqueCount(docCountPipe, new Fields("term"), new Fields("docid"), new Fields("doc_count"));
// In the docCountPipe, we now have (term, doc_count). One of these tuples has ("", total doc count), so
// we want to do a filter & HashJoin against itself
Pipe totalDocCountPipe = new Pipe("total doc count", docCountPipe);
totalDocCountPipe = new Each(totalDocCountPipe, new Fields("term"), new ExpressionFilter("term.isEmpty()", String.class));
totalDocCountPipe = new Each(totalDocCountPipe, new Fields("doc_count"), new Identity());
// Now we can do a cross join, so that the total doc count is joined to every one of our
// incoming tuples.
docCountPipe = new HashJoin(docCountPipe, Fields.NONE, totalDocCountPipe, Fields.NONE, new OuterJoin());
// Generate term, total count. This will
// include the empty term "" which will be the total count of all terms.
Pipe termCountPipe = new Pipe("term count", termsPipe);
termCountPipe = new SumBy(termCountPipe, new Fields("term"), new Fields("count"), new Fields("total_count"), Integer.class);
// Join termCountsPipe with our termsPipe by term, so we get
// docid, term, doc term count, total term count
// This will include docid, "", total terms in doc, global terms count
Pipe allTermData = new CoGroup( termsPipe, new Fields("term"),
termCountPipe, new Fields("term"),
new Fields("docid", "term", "term_count", "term_ignore", "total_count"),
new LeftJoin());
allTermData = new Each(allTermData, new Fields("docid", "term", "term_count", "total_count"), new Identity());
allTermData = new GroupBy(allTermData, new Fields("docid"), new Fields("term"));
allTermData = new Every(allTermData, new CalcLLR(numTerms));
setTails(allTermData);
}
}