/* * Copyright 2004-2014 H2 Group. Multiple-Licensed under the MPL 2.0, * and the EPL 1.0 (http://h2database.com/html/license.html). * Initial Developer: H2 Group */ package org.h2.dev.hash; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.Charset; import java.security.SecureRandom; import java.util.ArrayList; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.zip.Deflater; import java.util.zip.Inflater; /** * A minimal perfect hash function tool. It needs about 1.98 bits per key. * <p> * The algorithm is recursive: sets that contain no or only one entry are not * processed as no conflicts are possible. For sets that contain between 2 and * 14 entries, a number of hash functions are tested to check if they can store * the data without conflict. If no function was found, and for larger sets, the * set is split into a (possibly high) number of smaller sets, which are * processed recursively. The average size of a top-level bucket is about 216 * entries, and the maximum recursion level is typically 5. * <p> * At the end of the generation process, the data is compressed using a general * purpose compression tool (Deflate / Huffman coding) down to 2.0 bits per key. * The uncompressed data is around 2.2 bits per key. With arithmetic coding, * about 1.9 bits per key are needed. Generating the hash function takes about * 2.5 seconds per million keys with 8 cores (multithreaded). The algorithm * automatically scales with the number of available CPUs (using as many threads * as there are processors). At the expense of processing time, a lower number * of bits per key would be possible (for example 1.84 bits per key with 100000 * keys, using 32 seconds generation time, with Huffman coding). 
* <p> * The memory usage to efficiently calculate hash values is around 2.5 bits per * key (the space needed for the uncompressed description, plus 8 bytes for * every top-level bucket). * <p> * At each level, only one user-defined hash function per object is called * (about 3 hash functions per key). The result is further processed using a * supplemental hash function, so that the default user-defined hash function * doesn't need to be sophisticated (it doesn't need to be non-linear, have a * good avalanche effect, or generate random-looking data; it should just * produce few conflicts if possible). * <p> * To protect against hash flooding and similar attacks, a secure random seed * per hash table is used. For further protection, cryptographically secure * functions such as SipHash or SHA-256 can be used. However, such (slower) * functions only need to be used if regular hash functions produce too many * conflicts. This case is detected when generating the perfect hash function, * by checking if there are too many conflicts (more than 2160 entries in one * top-level bucket). In this case, the hash function of the next level is used. * That way, in the normal case, where no attack is happening, only fast, but * less secure, hash functions are called. It is fine to use the regular * hashCode method as the level 0 hash function. However, just relying on the * regular hashCode method does not work if the key has more than 32 bits, * because the risk of collisions is too high. Incorrect universal hash * functions are detected (an exception is thrown if there are more than 32 * recursion levels). * <p> * In-place updating of the hash table is not implemented but possible in * theory, by patching the hash function description. With a small change, * non-minimal perfect hash functions can be calculated (for example 1.22 bits * per key at a fill rate of 81%). 
/**
 * @param <K> the key type
 */
public class MinimalPerfectHash<K> {

    /**
     * Large buckets are typically divided into buckets of this size.
     */
    private static final int DIVIDE = 6;

    /**
     * For sets larger than this, instead of trying to map them uniquely to a
     * set of the same size, the size of the set is incremented by one. This
     * reduces the time to find a mapping, but the index of the hole also needs
     * to be stored, which increases the space usage.
     */
    private static final int SPEEDUP = 11;

    /**
     * The maximum size of a small bucket (one that is not further split if
     * possible).
     */
    private static final int MAX_SIZE = 14;

    /**
     * The maximum hash function level. Needing more levels than this means the
     * universal hash function does not separate the keys, so generation fails
     * with an exception instead of recursing (or retrying) forever.
     */
    private static final int MAX_LEVEL = 32;

    /**
     * The maximum offset for hash functions of small buckets. At most that many
     * hash functions are tried for the given size.
     */
    private static final int[] MAX_OFFSETS = { 0, 0, 8, 18, 47, 123, 319, 831,
            2162, 5622, 14617, 38006, 98815, 256920, 667993 };

    /**
     * The output value to split the bucket into many (more than 2) smaller
     * buckets.
     */
    private static final int SPLIT_MANY = 3;

    /**
     * The minimum output value for a small bucket of a given size.
     */
    private static final int[] SIZE_OFFSETS = new int[MAX_OFFSETS.length + 1];

    /**
     * A secure random generator.
     */
    private static final SecureRandom RANDOM = new SecureRandom();

    static {
        // for sizes >= SPEEDUP the stored value is offset * (size + 1) + hole,
        // so the range of encoded values is enlarged for those sizes
        for (int i = SPEEDUP; i < MAX_OFFSETS.length; i++) {
            MAX_OFFSETS[i] = (int) (MAX_OFFSETS[i] * 2.5);
        }
        // small buckets of size i are encoded in the value range
        // [SIZE_OFFSETS[i], SIZE_OFFSETS[i + 1])
        int last = SPLIT_MANY + 1;
        for (int i = 0; i < MAX_OFFSETS.length; i++) {
            SIZE_OFFSETS[i] = last;
            last += MAX_OFFSETS[i];
        }
        SIZE_OFFSETS[SIZE_OFFSETS.length - 1] = last;
    }

    /**
     * The universal hash function.
     */
    private final UniversalHash<K> hash;

    /**
     * The description of the hash function. Used for calculating the hash of a
     * key.
     */
    private final byte[] data;

    /**
     * The random seed.
     */
    private final int seed;

    /**
     * The size up to the given root-level bucket in the data array. Used to
     * speed up calculating the hash of a key.
     */
    private final int[] rootSize;

    /**
     * The position of the given root-level bucket in the data array. Used to
     * speed up calculating the hash of a key.
     */
    private final int[] rootPos;

    /**
     * The hash function level at the root of the tree. Typically 0, except if
     * the hash function at that level didn't split the entries as expected
     * (which can be due to a bad hash function, or due to an attack).
     */
    private final int rootLevel;

    /**
     * Create a hash object to convert keys to hashes.
     *
     * @param desc the data returned by the generate method
     * @param hash the universal hash function
     * @throws IllegalArgumentException if the description cannot be
     *             decompressed (for example because it is truncated)
     */
    public MinimalPerfectHash(byte[] desc, UniversalHash<K> hash) {
        this.hash = hash;
        byte[] b = data = expand(desc);
        // the first 4 bytes are the big-endian random seed
        seed = ((b[0] & 255) << 24) |
                ((b[1] & 255) << 16) |
                ((b[2] & 255) << 8) |
                (b[3] & 255);
        if (b[4] == SPLIT_MANY) {
            // the root level was written as the last byte by generate
            rootLevel = b[b.length - 1] & 255;
            int split = readVarInt(b, 5);
            rootSize = new int[split];
            rootPos = new int[split];
            int pos = 5 + getVarIntLength(b, 5);
            int sizeSum = 0;
            for (int i = 0; i < split; i++) {
                rootSize[i] = sizeSum;
                rootPos[i] = pos;
                int start = pos;
                pos = getNextPos(pos);
                sizeSum += getSizeSum(start, pos);
            }
        } else {
            rootLevel = 0;
            rootSize = null;
            rootPos = null;
        }
    }

    /**
     * Calculate the hash value for the given key. No membership check is
     * performed: keys that were not part of the original set are not detected.
     *
     * @param x the key
     * @return the hash value
     */
    public int get(K x) {
        return get(4, x, true, rootLevel);
    }

    /**
     * Get the hash value for the given key, starting at a certain position and
     * level.
     *
     * @param pos the start position
     * @param x the key
     * @param isRoot whether this is the root of the tree
     * @param level the level
     * @return the hash value
     */
    private int get(int pos, K x, boolean isRoot, int level) {
        int n = readVarInt(data, pos);
        if (n < 2) {
            return 0;
        } else if (n > SPLIT_MANY) {
            int size = getSize(n);
            int offset = getOffset(n, size);
            if (size >= SPEEDUP) {
                // the mapping uses size + 1 slots; p is the unused slot
                int p = offset % (size + 1);
                offset = offset / (size + 1);
                int result = hash(x, hash, level, seed, offset, size + 1);
                if (result >= p) {
                    result--;
                }
                return result;
            }
            return hash(x, hash, level, seed, offset, size);
        }
        // n is 2 or SPLIT_MANY, both stored as a single byte
        pos++;
        int split;
        if (n == SPLIT_MANY) {
            split = readVarInt(data, pos);
            pos += getVarIntLength(data, pos);
        } else {
            split = n;
        }
        int h = hash(x, hash, level, seed, 0, split);
        int s;
        if (isRoot && rootPos != null) {
            s = rootSize[h];
            pos = rootPos[h];
        } else {
            int start = pos;
            for (int i = 0; i < h; i++) {
                pos = getNextPos(pos);
            }
            s = getSizeSum(start, pos);
        }
        return s + get(pos, x, false, level + 1);
    }

    /**
     * Get the position of the next sibling.
     *
     * @param pos the position of this branch
     * @return the position of the next sibling
     */
    private int getNextPos(int pos) {
        int n = readVarInt(data, pos);
        pos += getVarIntLength(data, pos);
        if (n < 2 || n > SPLIT_MANY) {
            return pos;
        }
        int split;
        if (n == SPLIT_MANY) {
            split = readVarInt(data, pos);
            pos += getVarIntLength(data, pos);
        } else {
            split = n;
        }
        for (int i = 0; i < split; i++) {
            pos = getNextPos(pos);
        }
        return pos;
    }

    /**
     * The sum of the sizes between the start and end position.
     *
     * @param start the start position
     * @param end the end position (excluding)
     * @return the sizes
     */
    private int getSizeSum(int start, int end) {
        int s = 0;
        for (int pos = start; pos < end;) {
            int n = readVarInt(data, pos);
            pos += getVarIntLength(data, pos);
            if (n < 2) {
                s += n;
            } else if (n > SPLIT_MANY) {
                s += getSize(n);
            } else if (n == SPLIT_MANY) {
                // skip the number of child buckets
                pos += getVarIntLength(data, pos);
            }
        }
        return s;
    }

    /**
     * Write the encoded value of a small bucket.
     *
     * @param out the output stream
     * @param size the bucket size
     * @param offset the offset within the value range of that size
     */
    private static void writeSizeOffset(ByteArrayOutputStream out, int size,
            int offset) {
        writeVarInt(out, SIZE_OFFSETS[size] + offset);
    }

    /**
     * Get the offset part of an encoded small-bucket value.
     *
     * @param n the encoded value
     * @param size the bucket size, as returned by getSize
     * @return the offset
     */
    private static int getOffset(int n, int size) {
        return n - SIZE_OFFSETS[size];
    }

    /**
     * Get the bucket size of an encoded small-bucket value.
     *
     * @param n the encoded value
     * @return the size (0 if n is beyond the last range)
     */
    private static int getSize(int n) {
        for (int i = 0; i < SIZE_OFFSETS.length; i++) {
            if (n < SIZE_OFFSETS[i]) {
                return i - 1;
            }
        }
        return 0;
    }

    /**
     * Generate the minimal perfect hash function data from the given set of
     * keys.
     *
     * @param set the data
     * @param hash the universal hash function
     * @return the hash function description
     * @throws IllegalStateException if the universal hash function does not
     *             separate the keys within the maximum number of levels
     */
    public static <K> byte[] generate(Set<K> set, UniversalHash<K> hash) {
        ArrayList<K> list = new ArrayList<K>(set);
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        int seed = RANDOM.nextInt();
        out.write(seed >>> 24);
        out.write(seed >>> 16);
        out.write(seed >>> 8);
        out.write(seed);
        generate(list, hash, 0, seed, out);
        return compress(out.toByteArray());
    }

    /**
     * Generate the perfect hash function data from the given set of keys.
     *
     * @param list the data, in the form of a list (cleared if it is split)
     * @param hash the universal hash function
     * @param level the recursion level
     * @param seed the random seed
     * @param out the output stream
     */
    static <K> void generate(ArrayList<K> list, UniversalHash<K> hash,
            int level, int seed, ByteArrayOutputStream out) {
        int size = list.size();
        if (size <= 1) {
            out.write(size);
            return;
        }
        if (level > MAX_LEVEL) {
            throw new IllegalStateException("Too many recursions; " +
                    " incorrect universal hash function?");
        }
        if (size <= MAX_SIZE &&
                generateSmallBucket(list, hash, level, seed, out)) {
            return;
        }
        int split;
        if (size > 57 * DIVIDE) {
            split = size / (36 * DIVIDE);
        } else {
            split = (size - 47) / DIVIDE;
        }
        split = Math.max(2, split);
        boolean isRoot = level == 0;
        ArrayList<ArrayList<K>> lists;
        do {
            lists = new ArrayList<ArrayList<K>>(split);
            for (int i = 0; i < split; i++) {
                lists.add(new ArrayList<K>(size / split));
            }
            for (int i = 0; i < size; i++) {
                K x = list.get(i);
                ArrayList<K> l = lists.get(hash(x, hash, level, seed, 0, split));
                l.add(x);
                if (isRoot && split >= SPLIT_MANY &&
                        l.size() > 36 * DIVIDE * 10) {
                    // a bad hash function or attack was detected:
                    // retry with the hash function of the next level, but
                    // give up (instead of looping forever) after MAX_LEVEL
                    level++;
                    if (level > MAX_LEVEL) {
                        throw new IllegalStateException("Too many recursions; " +
                                " incorrect universal hash function?");
                    }
                    lists = null;
                    break;
                }
            }
        } while (lists == null);
        if (split >= SPLIT_MANY) {
            out.write(SPLIT_MANY);
        }
        writeVarInt(out, split);
        boolean multiThreaded = isRoot && list.size() > 1000;
        // release the memory of the source list before recursing
        list.clear();
        list.trimToSize();
        if (multiThreaded) {
            generateMultiThreaded(lists, hash, level, seed, out);
        } else {
            for (ArrayList<K> s2 : lists) {
                generate(s2, hash, level + 1, seed, out);
            }
        }
        if (isRoot && split >= SPLIT_MANY) {
            // read back by the constructor as the last byte
            out.write(level);
        }
    }

    /**
     * Try to find a conflict-free supplemental hash function for a small
     * bucket, and write its encoded value if one was found.
     *
     * @param list the keys (between 2 and MAX_SIZE entries)
     * @param hash the universal hash function
     * @param level the recursion level
     * @param seed the random seed
     * @param out the output stream
     * @return true if a mapping was found and written
     */
    private static <K> boolean generateSmallBucket(ArrayList<K> list,
            UniversalHash<K> hash, int level, int seed,
            ByteArrayOutputStream out) {
        int size = list.size();
        int maxOffset = MAX_OFFSETS[size];
        // get the hash codes - we could stop early
        // if we detect that two keys have the same hash
        int[] hashes = new int[size];
        for (int i = 0; i < size; i++) {
            hashes[i] = hash.hashCode(list.get(i), level, seed);
        }
        // use the supplemental hash function to find a way
        // to make the hash code unique within this group -
        // there might be a much faster way than that, by
        // checking which bits of the hash code matter most
        int testSize = size;
        if (size >= SPEEDUP) {
            testSize++;
            maxOffset /= testSize;
        }
        nextOffset:
        for (int offset = 0; offset < maxOffset; offset++) {
            int bits = 0;
            for (int i = 0; i < size; i++) {
                int h = hash(hashes[i], level, offset, testSize);
                if ((bits & (1 << h)) != 0) {
                    continue nextOffset;
                }
                bits |= 1 << h;
            }
            if (size >= SPEEDUP) {
                // the lowest unused slot is the hole
                int hole = Integer.numberOfTrailingZeros(~bits);
                writeSizeOffset(out, size, offset * (size + 1) + hole);
            } else {
                writeSizeOffset(out, size, offset);
            }
            return true;
        }
        return false;
    }

    /**
     * Generate the data for the given buckets, using one thread per
     * available processor, and append the results in bucket order.
     *
     * @param lists the buckets (emptied by this method)
     * @param hash the universal hash function
     * @param level the level of the parent
     * @param seed the random seed
     * @param out the output stream
     */
    private static <K> void generateMultiThreaded(
            final ArrayList<ArrayList<K>> lists,
            final UniversalHash<K> hash,
            final int level,
            final int seed,
            ByteArrayOutputStream out) {
        final ArrayList<ByteArrayOutputStream> outList =
                new ArrayList<ByteArrayOutputStream>();
        int processors = Runtime.getRuntime().availableProcessors();
        Thread[] threads = new Thread[processors];
        final AtomicInteger success = new AtomicInteger();
        final AtomicReference<Exception> failure = new AtomicReference<Exception>();
        for (int i = 0; i < processors; i++) {
            threads[i] = new Thread() {
                @Override
                public void run() {
                    try {
                        while (true) {
                            ArrayList<K> list;
                            ByteArrayOutputStream temp = new ByteArrayOutputStream();
                            // take the next bucket and reserve its output slot
                            // atomically, so the output order matches the input
                            synchronized (lists) {
                                if (lists.isEmpty()) {
                                    break;
                                }
                                list = lists.remove(0);
                                outList.add(temp);
                            }
                            generate(list, hash, level + 1, seed, temp);
                        }
                    } catch (Exception e) {
                        failure.set(e);
                        return;
                    }
                    success.incrementAndGet();
                }
            };
        }
        for (Thread t : threads) {
            t.start();
        }
        try {
            for (Thread t : threads) {
                t.join();
            }
            if (success.get() != threads.length) {
                Exception e = failure.get();
                if (e != null) {
                    throw new RuntimeException(e);
                }
                throw new RuntimeException("Unknown failure in one thread");
            }
            for (ByteArrayOutputStream temp : outList) {
                out.write(temp.toByteArray());
            }
        } catch (InterruptedException e) {
            // restore the interrupt status so that callers can observe it
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Calculate the hash of a key. The result depends on the key, the recursion
     * level, and the offset.
     *
     * @param o the key
     * @param hash the universal hash function
     * @param level the recursion level
     * @param seed the random seed
     * @param offset the index of the hash function
     * @param size the size of the bucket
     * @return the hash (a value between 0, including, and the size, excluding)
     */
    private static <K> int hash(K o, UniversalHash<K> hash, int level,
            int seed, int offset, int size) {
        return hash(hash.hashCode(o, level, seed), level, offset, size);
    }

    /**
     * Apply the supplemental hash function to a user hash code.
     *
     * @param x the hash code returned by the universal hash function
     * @param level the recursion level
     * @param offset the index of the hash function
     * @param size the size of the bucket
     * @return the hash (a value between 0, including, and the size, excluding)
     */
    private static int hash(int x, int level, int offset, int size) {
        x += level + offset * 32;
        x = ((x >>> 16) ^ x) * 0x45d9f3b;
        x = ((x >>> 16) ^ x) * 0x45d9f3b;
        x = (x >>> 16) ^ x;
        return (x & (-1 >>> 1)) % size;
    }

    /**
     * Write a variable size int (7 bits per byte, low bits first; the high
     * bit marks that more bytes follow).
     *
     * @param out the output stream
     * @param x the value
     * @return the number of bytes written
     */
    private static int writeVarInt(ByteArrayOutputStream out, int x) {
        int len = 0;
        while ((x & ~0x7f) != 0) {
            out.write((byte) (0x80 | (x & 0x7f)));
            x >>>= 7;
            len++;
        }
        out.write((byte) x);
        return ++len;
    }

    /**
     * Read a variable size int, as written by writeVarInt.
     *
     * @param d the data
     * @param pos the position
     * @return the value
     */
    private static int readVarInt(byte[] d, int pos) {
        int x = d[pos++];
        if (x >= 0) {
            return x;
        }
        x &= 0x7f;
        for (int s = 7; s < 64; s += 7) {
            int b = d[pos++];
            x |= (b & 0x7f) << s;
            if (b >= 0) {
                break;
            }
        }
        return x;
    }

    /**
     * Get the number of bytes of the variable size int at the given position.
     *
     * @param d the data
     * @param pos the position
     * @return the length in bytes
     */
    private static int getVarIntLength(byte[] d, int pos) {
        int x = d[pos++];
        if (x >= 0) {
            return 1;
        }
        int len = 2;
        for (int s = 7; s < 64; s += 7) {
            int b = d[pos++];
            if (b >= 0) {
                break;
            }
            len++;
        }
        return len;
    }

    /**
     * Compress the hash description using a Huffman coding.
     *
     * @param d the data
     * @return the compressed data
     */
    private static byte[] compress(byte[] d) {
        Deflater deflater = new Deflater();
        try {
            deflater.setStrategy(Deflater.HUFFMAN_ONLY);
            deflater.setInput(d);
            deflater.finish();
            ByteArrayOutputStream out2 = new ByteArrayOutputStream(d.length);
            byte[] buffer = new byte[1024];
            while (!deflater.finished()) {
                int count = deflater.deflate(buffer);
                out2.write(buffer, 0, count);
            }
            return out2.toByteArray();
        } finally {
            // release the native zlib resources even on failure
            deflater.end();
        }
    }

    /**
     * Decompress the hash description using a Huffman coding.
     *
     * @param d the data
     * @return the decompressed data
     * @throws IllegalArgumentException if the data is corrupt or truncated
     */
    private static byte[] expand(byte[] d) {
        Inflater inflater = new Inflater();
        try {
            inflater.setInput(d);
            ByteArrayOutputStream out = new ByteArrayOutputStream(d.length);
            byte[] buffer = new byte[1024];
            while (!inflater.finished()) {
                int count = inflater.inflate(buffer);
                if (count == 0 &&
                        (inflater.needsInput() || inflater.needsDictionary())) {
                    // without this check, truncated input loops forever
                    throw new IllegalArgumentException(
                            "Truncated or corrupt hash function description");
                }
                out.write(buffer, 0, count);
            }
            return out.toByteArray();
        } catch (java.util.zip.DataFormatException e) {
            throw new IllegalArgumentException(e);
        } finally {
            // release the native zlib resources even on failure
            inflater.end();
        }
    }

    /**
     * An interface that can calculate multiple hash values for an object. The
     * returned hash value of two distinct objects may be the same for a given
     * hash function index, but as more hash functions indexes are called for
     * those objects, the returned value must eventually be different.
     * <p>
     * The returned value does not need to be uniformly distributed.
     *
     * @param <T> the type
     */
    public interface UniversalHash<T> {

        /**
         * Calculate the hash of the given object.
         *
         * @param o the object
         * @param index the hash function index (index 0 is used first, so the
         *            method should be very fast with index 0; index 1 and so on
         *            are only called when really needed)
         * @param seed the random seed (always the same for a hash table)
         * @return the hash value
         */
        int hashCode(T o, int index, int seed);

    }

    /**
     * A sample hash implementation for long keys. The seed is not used.
     */
    public static class LongHash implements UniversalHash<Long> {

        @Override
        public int hashCode(Long o, int index, int seed) {
            if (index == 0) {
                return o.hashCode();
            } else if (index < 8) {
                long x = o.longValue();
                x += index;
                x = ((x >>> 32) ^ x) * 0x45d9f3b;
                x = ((x >>> 32) ^ x) * 0x45d9f3b;
                return (int) (x ^ (x >>> 32));
            }
            // get the lower or higher 32 bit depending on the index
            int shift = (index & 1) * 32;
            return (int) (o.longValue() >>> shift);
        }

    }

    /**
     * A sample hash implementation for String keys.
     */
    public static class StringHash implements UniversalHash<String> {

        private static final Charset UTF8 = Charset.forName("UTF-8");

        @Override
        public int hashCode(String o, int index, int seed) {
            if (index == 0) {
                // use the default hash of a string, which might already be
                // available
                return o.hashCode();
            } else if (index < 8) {
                // use a different hash function, which is fast but not
                // necessarily universal, and not cryptographically secure
                return getFastHash(o, index, seed);
            }
            // this method is supposed to be cryptographically secure;
            // we could use SHA-256 for higher indexes
            return getSipHash24(o, index, seed);
        }

        /**
         * A cryptographically weak hash function. It is supposed to be fast.
         *
         * @param o the string
         * @param index the hash function index
         * @param seed the seed
         * @return the hash value
         */
        public static int getFastHash(String o, int index, int seed) {
            int x = (index * 0x9f3b) ^ seed;
            int result = seed + o.length();
            for (int i = 0; i < o.length(); i++) {
                x = 31 + x * 0x9f3b;
                result ^= x * (1 + o.charAt(i));
            }
            return result;
        }

        /**
         * A cryptographically relatively secure hash function. It is supposed
         * to protect against hash-flooding denial-of-service attacks. The
         * string is hashed as UTF-8 bytes.
         *
         * @param o the string
         * @param k0 key 0
         * @param k1 key 1
         * @return the hash value
         */
        public static int getSipHash24(String o, long k0, long k1) {
            byte[] b = o.getBytes(UTF8);
            return getSipHash24(b, 0, b.length, k0, k1);
        }

        /**
         * A cryptographically relatively secure hash function (SipHash-2-4,
         * truncated to the low 32 bits). It is supposed to protect against
         * hash-flooding denial-of-service attacks.
         *
         * @param b the data
         * @param start the start position
         * @param end the end position plus one
         * @param k0 key 0
         * @param k1 key 1
         * @return the hash value
         */
        public static int getSipHash24(byte[] b, int start, int end, long k0,
                long k1) {
            long v0 = k0 ^ 0x736f6d6570736575L;
            long v1 = k1 ^ 0x646f72616e646f6dL;
            long v2 = k0 ^ 0x6c7967656e657261L;
            long v3 = k1 ^ 0x7465646279746573L;
            int repeat;
            for (int off = start; off <= end + 8; off += 8) {
                long m;
                if (off <= end) {
                    // message block; the last (partial) block carries the length
                    m = 0;
                    int i = 0;
                    for (; i < 8 && off + i < end; i++) {
                        m |= ((long) b[off + i] & 255) << (8 * i);
                    }
                    if (i < 8) {
                        m |= ((long) end - start) << 56;
                    }
                    v3 ^= m;
                    repeat = 2;
                } else {
                    // finalization
                    m = 0;
                    v2 ^= 0xff;
                    repeat = 4;
                }
                for (int i = 0; i < repeat; i++) {
                    v0 += v1;
                    v2 += v3;
                    v1 = Long.rotateLeft(v1, 13);
                    v3 = Long.rotateLeft(v3, 16);
                    v1 ^= v0;
                    v3 ^= v2;
                    v0 = Long.rotateLeft(v0, 32);
                    v2 += v1;
                    v0 += v3;
                    v1 = Long.rotateLeft(v1, 17);
                    v3 = Long.rotateLeft(v3, 21);
                    v1 ^= v2;
                    v3 ^= v0;
                    v2 = Long.rotateLeft(v2, 32);
                }
                v0 ^= m;
            }
            return (int) (v0 ^ v1 ^ v2 ^ v3);
        }

    }

}