package com.bazaarvoice.ostrich.partition;

import com.bazaarvoice.ostrich.PartitionContext;
import com.bazaarvoice.ostrich.ServiceEndPoint;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;

/**
 * Uses consistent hashing to map service calls to end points.  Partitions are mapped to servers based on hashes of
 * the service end point ID strings (i.e. ip:port).
 * <p/>
 * Choose this partition filter when every server can handle every request, but throughput is increased if requests
 * on the same data are directed to the same server.  For example, choose this partition filter to distribute
 * requests across a set of memcached servers.
 * <p/>
 * The algorithm is inspired by:
 * <a href="http://last.fm/user/RJ/journal/2007/04/10/rz_libketama_-_a_consistent_hashing_algo_for_memcache_clients">
 * libketama</a>
 */
public class ConsistentHashPartitionFilter implements PartitionFilter {
    private static final int DEFAULT_ENTRIES_PER_END_POINT = 100;

    private final int _entriesPerEndPoint;
    private final List<String> _partitionKeys;
    private final NavigableMap<Integer, String> _ring = Maps.newTreeMap();
    private Map<String, ServiceEndPoint> _endPointsById = Maps.newHashMap();

    /**
     * Constructs a default {@code ConsistentHashPartitionFilter} that uses the default partition key
     * ({@link com.bazaarvoice.ostrich.PartitionContext#get()}) to determine the partition.
     */
    public ConsistentHashPartitionFilter() {
        this(Collections.<String>emptyList());
    }

    /**
     * Constructs a {@code ConsistentHashPartitionFilter} that concatenates the partition context values for the
     * specified set of keys to determine the partition.
     */
    public ConsistentHashPartitionFilter(String... partitionKeys) {
        this(Arrays.asList(partitionKeys));
    }

    /**
     * Constructs a {@code ConsistentHashPartitionFilter} that concatenates the partition context values for the
     * specified set of keys to determine the partition.
     */
    public ConsistentHashPartitionFilter(List<String> partitionKeys) {
        this(partitionKeys, DEFAULT_ENTRIES_PER_END_POINT);
    }

    /**
     * Constructs a {@code ConsistentHashPartitionFilter} that concatenates the partition context values for the
     * specified set of keys to determine the partition, placing the specified number of entries in the ring for
     * each end point.
     */
    private ConsistentHashPartitionFilter(List<String> partitionKeys, int entriesPerEndPoint) {
        _partitionKeys = partitionKeys;
        _entriesPerEndPoint = entriesPerEndPoint;
    }

    @Override
    public Iterable<ServiceEndPoint> filter(Iterable<ServiceEndPoint> endPoints, PartitionContext partitionContext) {
        HashCode partitionHash = getPartitionHash(partitionContext);
        if (partitionHash == null) {
            return endPoints;  // No partition hash means any server can handle the request.
        }

        // The choose() method is synchronized.  Do any prep work we can up front before calling into it.
        Map<String, ServiceEndPoint> endPointsById = indexById(endPoints);

        ServiceEndPoint endPoint = choose(endPointsById, partitionHash);
        return Collections.singleton(endPoint);
    }
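    /**
     * Computes a deterministic hash of the partition context values for the configured partition keys, or of the
     * default context value ({@link PartitionContext#get()}) when no keys were configured.  Returns {@code null}
     * when the context contains no relevant values, meaning the request is not tied to any partition.
     */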
    private HashCode getPartitionHash(PartitionContext partitionContext) {
        // The precise implementation of this method isn't particularly important.  There are lots of ways we can
        // hash the data in the PartitionContext.  It just needs to be deterministic and to take into account the
        // values in the PartitionContext for the configured partition keys.
        Hasher hasher = Hashing.md5().newHasher();
        boolean empty = true;

        if (_partitionKeys.isEmpty()) {
            // Use the default context.
            Object value = partitionContext.get();
            if (value != null) {
                putUnencodedChars(hasher, value.toString());
                empty = false;
            }
        }

        for (String partitionKey : _partitionKeys) {
            Object value = partitionContext.get(partitionKey);
            if (value != null) {
                // Include both the key and value in the hash so a "reviewId" of 1 and a "reviewerId" of 1 hash
                // differently.
                putUnencodedChars(hasher, partitionKey);
                putUnencodedChars(hasher, value.toString());
                empty = false;
            }
        }

        if (empty) {
            // The partition context has no relevant values, so ignore it and don't filter the end points by
            // partition.  Return null to indicate this.
            return null;
        }

        return hasher.hash();
    }

    private synchronized ServiceEndPoint choose(Map<String, ServiceEndPoint> endPointsById, HashCode partitionHash) {
        // Update the ring if the set of active end points has changed.
        for (String endPointId : Sets.difference(_endPointsById.keySet(), endPointsById.keySet())) {
            for (Integer hash : computeHashCodes(endPointId)) {
                _ring.remove(hash);
            }
        }
        for (String endPointId : Sets.difference(endPointsById.keySet(), _endPointsById.keySet())) {
            for (Integer hash : computeHashCodes(endPointId)) {
                _ring.put(hash, endPointId);
            }
        }
        if (!_endPointsById.equals(endPointsById)) {
            _endPointsById = endPointsById;
        }

        // For the given partition hash, find its location in the ring and return its associated end point.
        Map.Entry<Integer, String> entry = _ring.ceilingEntry(partitionHash.asInt());
        if (entry == null) {
            entry = _ring.firstEntry();
        }
        return _endPointsById.get(entry.getValue());
    }

    /**
     * Returns a list of pseudo-random 32-bit values derived from the specified end point ID.
     */
    private List<Integer> computeHashCodes(String endPointId) {
        // Use the libketama approach of using MD5 hashes to generate 32-bit random values.  This assigns a set of
        // randomly generated ranges to each end point.  The individual ranges may vary widely in size, but, with a
        // sufficient number of entries per end point, the overall amount of data assigned to each server tends to
        // even out with minimal variation (256 entries per server yields roughly 5% variation in server load).
        List<Integer> list = Lists.newArrayListWithCapacity(_entriesPerEndPoint);
        for (int i = 0; list.size() < _entriesPerEndPoint; i++) {
            Hasher hasher = Hashing.md5().newHasher();
            hasher.putInt(i);
            putUnencodedChars(hasher, endPointId);
            ByteBuffer buf = ByteBuffer.wrap(hasher.hash().asBytes());
            while (buf.hasRemaining() && list.size() < _entriesPerEndPoint) {
                list.add(buf.getInt());
            }
        }
        return list;
    }

    /**
     * Returns a map of {@link ServiceEndPoint} objects indexed by their ID.
     */
    private Map<String, ServiceEndPoint> indexById(Iterable<ServiceEndPoint> endPoints) {
        Map<String, ServiceEndPoint> map = Maps.newHashMap();
        for (ServiceEndPoint endPoint : endPoints) {
            map.put(endPoint.getId(), endPoint);
        }
        return map;
    }

    private void putUnencodedChars(Hasher hasher, CharSequence charSequence) {
        // This is equivalent to Guava 15.0+'s Hasher.putUnencodedChars(CharSequence) but is backward compatible to
        // Guava 11.0-14.0.1, in which it was called Hasher.putString(CharSequence).
        for (int i = 0; i < charSequence.length(); i++) {
            hasher.putChar(charSequence.charAt(i));
        }
    }
}
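
// The class below is not part of the original filter: it is a minimal, self-contained sketch of the ketama-style
// ring construction and lookup that computeHashCodes() and choose() implement above, kept separate so the filter
// itself stays untouched.  The end point IDs and the "customerId42" key are made-up values for illustration only.
class ConsistentHashRingSketch {
    public static void main(String[] args) {
        // Build the ring: each end point gets 100 pseudo-random 32-bit positions.  Each MD5 hash of
        // (counter, endPointId) yields 128 bits, i.e. four 32-bit ring positions.
        NavigableMap<Integer, String> ring = Maps.newTreeMap();
        int entriesPerEndPoint = 100;
        for (String endPointId : Arrays.asList("10.0.0.1:8080", "10.0.0.2:8080", "10.0.0.3:8080")) {
            int added = 0;
            for (int i = 0; added < entriesPerEndPoint; i++) {
                Hasher hasher = Hashing.md5().newHasher().putInt(i);
                for (int c = 0; c < endPointId.length(); c++) {
                    hasher.putChar(endPointId.charAt(c));
                }
                ByteBuffer buf = ByteBuffer.wrap(hasher.hash().asBytes());
                while (buf.hasRemaining() && added < entriesPerEndPoint) {
                    ring.put(buf.getInt(), endPointId);
                    added++;
                }
            }
        }

        // Hash a partition key the same way getPartitionHash() hashes context values, one char at a time.
        Hasher keyHasher = Hashing.md5().newHasher();
        String key = "customerId42";
        for (int c = 0; c < key.length(); c++) {
            keyHasher.putChar(key.charAt(c));
        }
        int partitionHash = keyHasher.hash().asInt();

        // Look up the owning end point the same way choose() does: take the first ring entry at or above the
        // partition hash, wrapping around to the first entry when the hash is past the last ring position.
        Map.Entry<Integer, String> entry = ring.ceilingEntry(partitionHash);
        if (entry == null) {
            entry = ring.firstEntry();
        }
        System.out.println("partition hash " + partitionHash + " -> " + entry.getValue());
    }
}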