/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.analysis.minhash;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;

/**
 * Generate min hash tokens from an incoming stream of tokens. The incoming tokens would typically
 * be 5 word shingles.
 *
 * The number of hashes used and the number of minimum values for each hash can be set. You could
 * have 1 hash and keep the 100 lowest values or 100 hashes and keep the lowest one for each.
 * Hashes can also be bucketed in ranges over the 128-bit hash space,
 *
 * A 128-bit hash is used internally.
5 word shingles from 10e5 words generate 10e25 combinations So a 64 bit hash would * have collisions (1.8e19) * * When using different hashes 32 bits are used for the hash position leaving scope for 8e28 unique hashes. A single * hash will use all 128 bits. * */ public class MinHashFilter extends TokenFilter { private static final int HASH_CACHE_SIZE = 512; private static final LongPair[] cachedIntHashes = new LongPair[HASH_CACHE_SIZE]; public static final int DEFAULT_HASH_COUNT = 1; public static final int DEFAULT_HASH_SET_SIZE = 1; public static final int DEFAULT_BUCKET_COUNT = 512; static final String MIN_HASH_TYPE = "MIN_HASH"; private final List<List<FixedSizeTreeSet<LongPair>>> minHashSets; private int hashSetSize = DEFAULT_HASH_SET_SIZE; private int bucketCount = DEFAULT_BUCKET_COUNT; private int hashCount = DEFAULT_HASH_COUNT; private boolean requiresInitialisation = true; private State endState; private int hashPosition = -1; private int bucketPosition = -1; private long bucketSize; private final boolean withRotation; private int endOffset; private boolean exhausted = false; private final CharTermAttribute termAttribute = addAttribute(CharTermAttribute.class); private final OffsetAttribute offsetAttribute = addAttribute(OffsetAttribute.class); private final TypeAttribute typeAttribute = addAttribute(TypeAttribute.class); private final PositionIncrementAttribute posIncAttribute = addAttribute(PositionIncrementAttribute.class); private final PositionLengthAttribute posLenAttribute = addAttribute(PositionLengthAttribute.class); static { for (int i = 0; i < HASH_CACHE_SIZE; i++) { cachedIntHashes[i] = new LongPair(); murmurhash3_x64_128(getBytes(i), 0, 4, 0, cachedIntHashes[i]); } } static byte[] getBytes(int i) { byte[] answer = new byte[4]; answer[3] = (byte) (i); answer[2] = (byte) (i >> 8); answer[1] = (byte) (i >> 16); answer[0] = (byte) (i >> 24); return answer; } /** * create a MinHash filter * * @param input the token stream * @param hashCount the no. 
of hashes * @param bucketCount the no. of buckets for hashing * @param hashSetSize the no. of min hashes to keep * @param withRotation whether rotate or not hashes while incrementing tokens */ public MinHashFilter(TokenStream input, int hashCount, int bucketCount, int hashSetSize, boolean withRotation) { super(input); if (hashCount <= 0) { throw new IllegalArgumentException("hashCount must be greater than zero"); } if (bucketCount <= 0) { throw new IllegalArgumentException("bucketCount must be greater than zero"); } if (hashSetSize <= 0) { throw new IllegalArgumentException("hashSetSize must be greater than zero"); } this.hashCount = hashCount; this.bucketCount = bucketCount; this.hashSetSize = hashSetSize; this.withRotation = withRotation; this.bucketSize = (1L << 32) / bucketCount; if((1L << 32) % bucketCount != 0) { bucketSize++; } minHashSets = new ArrayList<>(this.hashCount); for (int i = 0; i < this.hashCount; i++) { ArrayList<FixedSizeTreeSet<LongPair>> buckets = new ArrayList<>(this.bucketCount); minHashSets.add(buckets); for (int j = 0; j < this.bucketCount; j++) { FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<>(this.hashSetSize); buckets.add(minSet); } } doRest(); } @Override public final boolean incrementToken() throws IOException { // Pull the underlying stream of tokens // Hash each token found // Generate the required number of variants of this hash // Keep the minimum hash value found so far of each variant int positionIncrement = 0; if (requiresInitialisation) { requiresInitialisation = false; boolean found = false; // First time through so we pull and hash everything while (input.incrementToken()) { found = true; String current = new String(termAttribute.buffer(), 0, termAttribute.length()); for (int i = 0; i < hashCount; i++) { byte[] bytes = current.getBytes("UTF-16LE"); LongPair hash = new LongPair(); murmurhash3_x64_128(bytes, 0, bytes.length, 0, hash); LongPair rehashed = combineOrdered(hash, getIntHash(i)); 
minHashSets.get(i).get((int) ((rehashed.val2 >>> 32) / bucketSize)).add(rehashed); } endOffset = offsetAttribute.endOffset(); } exhausted = true; input.end(); // We need the end state so an underlying shingle filter can have its state restored correctly. endState = captureState(); if (!found) { return false; } positionIncrement = 1; // fix up any wrap around bucket values. ... if (withRotation && (hashSetSize == 1)) { for (int hashLoop = 0; hashLoop < hashCount; hashLoop++) { for (int bucketLoop = 0; bucketLoop < bucketCount; bucketLoop++) { if (minHashSets.get(hashLoop).get(bucketLoop).size() == 0) { for (int bucketOffset = 1; bucketOffset < bucketCount; bucketOffset++) { if (minHashSets.get(hashLoop).get((bucketLoop + bucketOffset) % bucketCount).size() > 0) { LongPair replacementHash = minHashSets.get(hashLoop).get((bucketLoop + bucketOffset) % bucketCount) .first(); minHashSets.get(hashLoop).get(bucketLoop).add(replacementHash); break; } } } } } } } clearAttributes(); while (hashPosition < hashCount) { if (hashPosition == -1) { hashPosition++; } else { while (bucketPosition < bucketCount) { if (bucketPosition == -1) { bucketPosition++; } else { LongPair hash = minHashSets.get(hashPosition).get(bucketPosition).pollFirst(); if (hash != null) { termAttribute.setEmpty(); if (hashCount > 1) { termAttribute.append(int0(hashPosition)); termAttribute.append(int1(hashPosition)); } long high = hash.val2; termAttribute.append(long0(high)); termAttribute.append(long1(high)); termAttribute.append(long2(high)); termAttribute.append(long3(high)); long low = hash.val1; termAttribute.append(long0(low)); termAttribute.append(long1(low)); if (hashCount == 1) { termAttribute.append(long2(low)); termAttribute.append(long3(low)); } posIncAttribute.setPositionIncrement(positionIncrement); offsetAttribute.setOffset(0, endOffset); typeAttribute.setType(MIN_HASH_TYPE); posLenAttribute.setPositionLength(1); return true; } else { bucketPosition++; } } } bucketPosition = -1; 
hashPosition++; } } return false; } private static LongPair getIntHash(int i) { if (i < HASH_CACHE_SIZE) { return cachedIntHashes[i]; } else { LongPair answer = new LongPair(); murmurhash3_x64_128(getBytes(i), 0, 4, 0, answer); return answer; } } @Override public void end() throws IOException { if(!exhausted) { input.end(); } restoreState(endState); } @Override public void reset() throws IOException { super.reset(); doRest(); } private void doRest() { for (int i = 0; i < hashCount; i++) { for (int j = 0; j < bucketCount; j++) { minHashSets.get(i).get(j).clear(); } } endState = null; hashPosition = -1; bucketPosition = -1; requiresInitialisation = true; exhausted = false; } private static char long0(long x) { return (char) (x >> 48); } private static char long1(long x) { return (char) (x >> 32); } private static char long2(long x) { return (char) (x >> 16); } private static char long3(long x) { return (char) (x); } private static char int0(int x) { return (char) (x >> 16); } private static char int1(int x) { return (char) (x); } static boolean isLessThanUnsigned(long n1, long n2) { return (n1 < n2) ^ ((n1 < 0) != (n2 < 0)); } static class FixedSizeTreeSet<E extends Comparable<E>> extends TreeSet<E> { /** * */ private static final long serialVersionUID = -8237117170340299630L; private final int capacity; FixedSizeTreeSet() { this(20); } FixedSizeTreeSet(int capacity) { super(); this.capacity = capacity; } @Override public boolean add(final E toAdd) { if (capacity <= size()) { final E lastElm = last(); if (toAdd.compareTo(lastElm) > -1) { return false; } else { pollLast(); } } return super.add(toAdd); } } private static LongPair combineOrdered(LongPair... 
hashCodes) { LongPair result = new LongPair(); for (LongPair hashCode : hashCodes) { result.val1 = result.val1 * 37 + hashCode.val1; result.val2 = result.val2 * 37 + hashCode.val2; } return result; } /** 128 bits of state */ static final class LongPair implements Comparable<LongPair> { public long val1; public long val2; /* * (non-Javadoc) * * @see java.lang.Comparable#compareTo(java.lang.Object) */ @Override public int compareTo(LongPair other) { if (isLessThanUnsigned(val2, other.val2)) { return -1; } else if (val2 == other.val2) { if (isLessThanUnsigned(val1, other.val1)) { return -1; } else if (val1 == other.val1) { return 0; } else { return 1; } } else { return 1; } } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LongPair longPair = (LongPair) o; return val1 == longPair.val1 && val2 == longPair.val2; } @Override public int hashCode() { int result = (int) (val1 ^ (val1 >>> 32)); result = 31 * result + (int) (val2 ^ (val2 >>> 32)); return result; } } /** Gets a long from a byte buffer in little endian byte order. */ private static long getLongLittleEndian(byte[] buf, int offset) { return ((long) buf[offset + 7] << 56) // no mask needed | ((buf[offset + 6] & 0xffL) << 48) | ((buf[offset + 5] & 0xffL) << 40) | ((buf[offset + 4] & 0xffL) << 32) | ((buf[offset + 3] & 0xffL) << 24) | ((buf[offset + 2] & 0xffL) << 16) | ((buf[offset + 1] & 0xffL) << 8) | ((buf[offset] & 0xffL)); // no shift needed } /** Returns the MurmurHash3_x64_128 hash, placing the result in "out". */ @SuppressWarnings("fallthrough") // the huge switch is designed to use fall through into cases! static void murmurhash3_x64_128(byte[] key, int offset, int len, int seed, LongPair out) { // The original algorithm does have a 32 bit unsigned seed. // We have to mask to match the behavior of the unsigned types and prevent sign extension. 
long h1 = seed & 0x00000000FFFFFFFFL; long h2 = seed & 0x00000000FFFFFFFFL; final long c1 = 0x87c37b91114253d5L; final long c2 = 0x4cf5ad432745937fL; int roundedEnd = offset + (len & 0xFFFFFFF0); // round down to 16 byte block for (int i = offset; i < roundedEnd; i += 16) { long k1 = getLongLittleEndian(key, i); long k2 = getLongLittleEndian(key, i + 8); k1 *= c1; k1 = Long.rotateLeft(k1, 31); k1 *= c2; h1 ^= k1; h1 = Long.rotateLeft(h1, 27); h1 += h2; h1 = h1 * 5 + 0x52dce729; k2 *= c2; k2 = Long.rotateLeft(k2, 33); k2 *= c1; h2 ^= k2; h2 = Long.rotateLeft(h2, 31); h2 += h1; h2 = h2 * 5 + 0x38495ab5; } long k1 = 0; long k2 = 0; switch (len & 15) { case 15: k2 = (key[roundedEnd + 14] & 0xffL) << 48; case 14: k2 |= (key[roundedEnd + 13] & 0xffL) << 40; case 13: k2 |= (key[roundedEnd + 12] & 0xffL) << 32; case 12: k2 |= (key[roundedEnd + 11] & 0xffL) << 24; case 11: k2 |= (key[roundedEnd + 10] & 0xffL) << 16; case 10: k2 |= (key[roundedEnd + 9] & 0xffL) << 8; case 9: k2 |= (key[roundedEnd + 8] & 0xffL); k2 *= c2; k2 = Long.rotateLeft(k2, 33); k2 *= c1; h2 ^= k2; case 8: k1 = ((long) key[roundedEnd + 7]) << 56; case 7: k1 |= (key[roundedEnd + 6] & 0xffL) << 48; case 6: k1 |= (key[roundedEnd + 5] & 0xffL) << 40; case 5: k1 |= (key[roundedEnd + 4] & 0xffL) << 32; case 4: k1 |= (key[roundedEnd + 3] & 0xffL) << 24; case 3: k1 |= (key[roundedEnd + 2] & 0xffL) << 16; case 2: k1 |= (key[roundedEnd + 1] & 0xffL) << 8; case 1: k1 |= (key[roundedEnd] & 0xffL); k1 *= c1; k1 = Long.rotateLeft(k1, 31); k1 *= c2; h1 ^= k1; } // ---------- // finalization h1 ^= len; h2 ^= len; h1 += h2; h2 += h1; h1 = fmix64(h1); h2 = fmix64(h2); h1 += h2; h2 += h1; out.val1 = h1; out.val2 = h2; } private static long fmix64(long k) { k ^= k >>> 33; k *= 0xff51afd7ed558ccdL; k ^= k >>> 33; k *= 0xc4ceb9fe1a85ec53L; k ^= k >>> 33; return k; } }