package com.scaleunlimited.cascading.ml;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import cascading.flow.FlowProcess;
import cascading.operation.BaseOperation;
import cascading.operation.Buffer;
import cascading.operation.BufferCall;
import cascading.operation.Debug;
import cascading.operation.DebugLevel;
import cascading.operation.Function;
import cascading.operation.FunctionCall;
import cascading.operation.OperationCall;
import cascading.operation.aggregator.First;
import cascading.operation.expression.ExpressionFunction;
import cascading.pipe.Each;
import cascading.pipe.Every;
import cascading.pipe.GroupBy;
import cascading.pipe.Pipe;
import cascading.pipe.SubAssembly;
import cascading.pipe.assembly.CountBy;
import cascading.pipe.assembly.Unique;
import cascading.tuple.Fields;
import cascading.tuple.Tuple;
import cascading.tuple.TupleEntry;
/**
* We get passed a Tuple that has two fields in it - a document id, and a "terms" string.
* We'll output the top N similar documents (with similarity scores) for each unique doc ID.
*
*/
@SuppressWarnings("serial")
public class SimHash extends SubAssembly {
// Name (and Fields wrapper) of the output field carrying the id of the document
// that was paired with the source document by EmitMatchingDocs.
public static final String SIMILAR_DOC_ID_FN = "SimHash_similarDocId";
public static final Fields SIMILAR_DOC_ID_FIELD = new Fields(SIMILAR_DOC_ID_FN);
// Name (and Fields wrapper) of the output field carrying the similarity score,
// computed in the constructor as (pair count / numHashes).
public static final String SIMILARITY_FN = "SimHash_similarity";
public static final Fields SIMILARITY_FIELD = new Fields(SIMILARITY_FN);
// Internal field holding the hash that CalcHash produces for each term.
private static final String TERM_HASH_FN = "SimHash_termHash";
private static final Fields TERM_HASH_FIELD = new Fields(TERM_HASH_FN);
// Internal field holding the CountBy result: how many times a given
// (doc id, similar doc id) pair was emitted, i.e. the number of shared hashes.
private static final Fields NUM_SIMILAR_DOCS_FIELD = new Fields("SimHash_numSimilarDocs");
/**
 * Buffer that turns one group of tuples into doc id pairs.
 *
 * For every tuple in the group, one output tuple is emitted for each doc id
 * that appeared earlier in the same group: slot 0 (the doc id field) holds the
 * earlier id, slot 1 ({@link #SIMILAR_DOC_ID_FN}) holds the current id.
 */
private static class EmitMatchingDocs extends BaseOperation<Void> implements Buffer<Void> {

    // Name of the argument field the doc id is read from.
    private String _docIdFieldname;

    // Two-slot output tuple, reused for every emitted pair. Transient, so it is
    // created in prepare() rather than in the constructor.
    private transient Tuple _result;

    public EmitMatchingDocs(String docIdFieldname) {
        super(new Fields(docIdFieldname, SIMILAR_DOC_ID_FN));
        _docIdFieldname = docIdFieldname;
    }

    @Override
    public void prepare(FlowProcess flowProcess, OperationCall<Void> operationCall) {
        super.prepare(flowProcess, operationCall);
        _result = Tuple.size(2);
    }

    @Override
    public void operate(FlowProcess flowProcess, BufferCall<Void> bufferCall) {
        // The group is keyed on hash and sorted by doc id. Every pair of doc ids in
        // the group has to be emitted, but only in one direction: (A, B) is enough,
        // (B, A) is not emitted as well.
        List<Object> earlierDocIds = new ArrayList<Object>();
        Iterator<TupleEntry> entries = bufferCall.getArgumentsIterator();

        while (entries.hasNext()) {
            Object currentDocId = entries.next().getObject(_docIdFieldname);

            // The current id is the "similar" side of every pair emitted in this pass.
            _result.set(1, currentDocId);

            int numEarlier = earlierDocIds.size();
            for (int i = 0; i < numEarlier; i++) {
                _result.set(0, earlierDocIds.get(i));
                bufferCall.getOutputCollector().add(_result);
            }

            earlierDocIds.add(currentDocId);
        }
    }
}
/**
 * Encode a string as UTF-8.
 *
 * @param str String to encode
 * @return the UTF-8 bytes of <code>str</code>
 * @throws RuntimeException if the UTF-8 charset is reported as unsupported
 */
private static byte[] getUTF8Bytes(String str) {
    final byte[] encoded;
    try {
        encoded = str.getBytes("UTF-8");
    } catch (UnsupportedEncodingException missingCharset) {
        // Wrap the checked exception so callers don't have to declare it.
        throw new RuntimeException("Impossible missing charset exception", missingCharset);
    }
    return encoded;
}
/**
 * Generate a 64-bit JOAAT hash from the UTF-8 bytes of {@code s}.
 *
 * @param s String to hash
 * @return 64-bit hash
 */
private static long getLongHash(String s) {
    final byte[] utf8 = getUTF8Bytes(s);
    final int numBytes = utf8.length;
    // Hash the whole encoded array, starting at its first byte.
    return getLongHash(utf8, 0, numBytes);
}
/**
 * Generate a 64-bit JOAAT hash from a slice of the given byte array.
 *
 * Each byte is added as an unsigned value and then mixed in with a
 * shift-add / shift-xor step; three more mixing steps are applied once all
 * bytes have been consumed. The right shifts are arithmetic (sign-propagating)
 * shifts on a signed long.
 *
 * @param b Bytes to hash
 * @param offset starting offset
 * @param length number of bytes to hash
 * @return 64-bit hash
 */
private static long getLongHash(byte[] b, int offset, int length) {
    long hash = 0L;
    int cursor = offset;

    for (int remaining = length; remaining > 0; remaining--) {
        // Mask to 0..255 so negative bytes are not sign-extended into the sum.
        hash += b[cursor++] & 0xFFL;
        hash += hash << 20;
        hash ^= hash >> 12;
    }

    // Final avalanche.
    hash += hash << 6;
    hash ^= hash >> 22;
    hash += hash << 30;
    return hash;
}
/**
 * Function that emits a single long hash ({@link #TERM_HASH_FIELD}) for the
 * first argument value.
 *
 * String values are hashed with {@link #getLongHash(String)}; any other value
 * falls back to its {@code hashCode()}. A null argument is not handled
 * specially and fails when {@code hashCode()} is invoked on it.
 */
private static class CalcHash extends BaseOperation<Void> implements Function<Void> {

    // One-slot output tuple, reused across calls. Transient, so it is created
    // in prepare() rather than in the constructor.
    private transient Tuple _result;

    public CalcHash() {
        super(TERM_HASH_FIELD);
    }

    @Override
    public void prepare(FlowProcess flowProcess, OperationCall<Void> operationCall) {
        super.prepare(flowProcess, operationCall);
        _result = Tuple.size(1);
    }

    @Override
    public void operate(FlowProcess flowProcess, FunctionCall<Void> functionCall) {
        Object term = functionCall.getArguments().getObject(0);

        // The int hashCode() is widened to long in the non-String case.
        long termHash = (term instanceof String)
            ? getLongHash((String) term)
            : term.hashCode();

        _result.setLong(0, termHash);
        functionCall.getOutputCollector().add(_result);
    }
}
public SimHash(Pipe sourcePipe, String docIdFieldname, String termFieldname, int numHashes, int numSimilarDocs) {
super(sourcePipe);
// Calculate hash for each tuple. First leave one unique value per document.
// FUTURE we could defer this until a custom Buffer instead of First(numHashes), to avoid
// an extra job in the workflow.
sourcePipe = new Unique(sourcePipe, new Fields(docIdFieldname, termFieldname));
// sourcePipe = new Each(sourcePipe, new Fields(termFieldname), new ExpressionFunction(TERM_HASH_FIELD, "$0.hashCode()", String.class), Fields.SWAP);
sourcePipe = new Each(sourcePipe, DebugLevel.VERBOSE, new Debug("terms", true));
sourcePipe = new Each(sourcePipe, new Fields(termFieldname), new CalcHash(), Fields.SWAP);
sourcePipe = new Each(sourcePipe, DebugLevel.VERBOSE, new Debug("raw hashes", true));
// Group by doc, sort by hash, pick the first numHashes
sourcePipe = new GroupBy("Pick min hashes", sourcePipe, new Fields(docIdFieldname), TERM_HASH_FIELD);
// TODO want to fail if we don't get at least numHashes per doc
sourcePipe = new Every(sourcePipe, new First(numHashes), Fields.RESULTS);
sourcePipe = new Each(sourcePipe, DebugLevel.VERBOSE, new Debug("min hashes", true));
// Group by hash, sort by doc, emit matches
sourcePipe = new GroupBy("Emit matching docs", sourcePipe, TERM_HASH_FIELD, new Fields(docIdFieldname));
sourcePipe = new Every(sourcePipe, new EmitMatchingDocs(docIdFieldname), Fields.RESULTS);
sourcePipe = new Each(sourcePipe, DebugLevel.VERBOSE, new Debug("matching docs", true));
// Group by doc id and "matching" doc id, count occurrences
sourcePipe = new CountBy("Count matching docs", sourcePipe, new Fields(docIdFieldname, SIMILAR_DOC_ID_FN), NUM_SIMILAR_DOCS_FIELD);
sourcePipe = new Each(sourcePipe, DebugLevel.VERBOSE, new Debug("matching doc count", true));
// Calculate similarity score.
sourcePipe = new Each(sourcePipe, NUM_SIMILAR_DOCS_FIELD, new ExpressionFunction(SIMILARITY_FIELD, String.format("$0/%s", numHashes), Float.class), Fields.SWAP);
sourcePipe = new Each(sourcePipe, DebugLevel.VERBOSE, new Debug("SimHash results", true));
// Limit to top numSimilarDocs by score
sourcePipe = new GroupBy("Emit top matches", sourcePipe, new Fields(docIdFieldname), SIMILARITY_FIELD, true);
sourcePipe = new Every(sourcePipe, new First(numSimilarDocs), Fields.RESULTS);
sourcePipe = new Each(sourcePipe, DebugLevel.VERBOSE, new Debug("SimHash top results", true));
setTails(sourcePipe);
}
}