package com.ganji.as.thrift.protocol.cluster.load.balance;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import com.ganji.as.thrift.protocol.client.request.ThriftClientInvocation;
import com.ganji.as.thrift.protocol.server.nodes.discovery.ServerNode;
import com.ganji.as.thrift.protocol.server.nodes.discovery.ServerNodeInfo;
import com.ganji.as.thrift.protocol.server.nodes.discovery.ServiceEndpoint;

/**
 * Load balancer that picks a server node by consistent hashing of the invocation's message.
 *
 * <p>Each real node is placed on a hash ring {@code VIRTUAL_NODES_PER_NODE} times (virtual
 * nodes). A message is routed to the real node behind the first virtual node at or clockwise
 * from the message's hash. Rings are cached and rebuilt when the caller passes a different
 * node-list instance or when the list's size has changed since the ring was built.
 *
 * <p>Created 2015-07-22.
 *
 * @author yikangfeng
 */
public class ConsistentHashLoadBalance extends AbstractLoadBalance {

    /** Number of virtual nodes placed on the ring for each real server node. */
    private static final int VIRTUAL_NODES_PER_NODE = 160;

    /**
     * Key hashed when there is no invocation at all. It equals what {@code Arrays.toString}
     * returns for a {@code null} array, so a missing invocation routes like a null message.
     */
    private static final String NULL_MESSAGE_KEY = "null";

    /** Cached rings, keyed by the {@code toString()} of a node from the list they were built from. */
    private final ConcurrentMap<String, ConsistentHashSelector<ServerNode>> selectors =
            new ConcurrentHashMap<>();

    private final Random random = new Random();

    private final HashFunction hashFunction = new MD5HashFunction();

    /**
     * Selects the node that owns the invocation's message on the ring built from
     * {@code serverNodes}.
     *
     * @param serverNodes candidate nodes. Rings are cached per list instance, so pass a new
     *     list (or change its size) when membership changes.
     * @param clientInvocation invocation whose message is hashed; may be {@code null}, in which
     *     case a fixed key is hashed
     * @return the selected node, or {@code null} if {@code serverNodes} is null or empty
     */
    @Override
    ServerNode doSelect(final List<ServerNode> serverNodes,
            final ThriftClientInvocation clientInvocation) {
        if (serverNodes == null || serverNodes.isEmpty()) {
            // Same convention as ConsistentHashSelector.get on an empty ring.
            return null;
        }
        // NOTE(review): the cache key is a randomly chosen node, so one list may be cached under
        // up to serverNodes.size() keys, each holding its own copy of the ring. Selection is not
        // affected (rings built from the same list are identical, assuming a deterministic hash
        // function), but memory and rebuild cost scale with the list size. Consider a stable key.
        final String cacheKey =
                serverNodes.get(random.nextInt(serverNodes.size())).toString();

        ConsistentHashSelector<ServerNode> selector = selectors.get(cacheKey);
        if (selector == null || !selector.isBuiltFrom(serverNodes)) {
            // Keep the ring we just built instead of re-reading the map. Another thread may have
            // replaced the entry with a ring for a different node list in the meantime.
            selector = new ConsistentHashSelector<>(
                    hashFunction, VIRTUAL_NODES_PER_NODE, serverNodes);
            selectors.put(cacheKey, selector);
        }
        return selector.get(hashKeyOf(clientInvocation));
    }

    /** Returns the ring lookup key for an invocation: its message rendered by Arrays.toString. */
    private static String hashKeyOf(final ThriftClientInvocation clientInvocation) {
        return clientInvocation == null
                ? NULL_MESSAGE_KEY
                : Arrays.toString(clientInvocation.getMessage());
    }

    /**
     * Consistent-hash ring over a fixed set of nodes.
     *
     * <p>Instances are shared between threads through the selector cache, and {@code doSelect}
     * only reads them after construction. {@link #add} and {@link #remove} mutate an unsynchronized
     * {@link TreeMap} and are not safe to call on a shared instance.
     */
    private static final class ConsistentHashSelector<T> {
        private final HashFunction hashFunction;

        /** Node collection this ring was built from; compared by reference to detect a new list. */
        private final Collection<T> source;

        /** Size of {@link #source} at build time; a difference signals in-place modification. */
        private final int nodeCount;

        /** Number of virtual nodes per real node. */
        private final int numberOfReplicas;

        /** Maps each virtual node's hash value to the real node it stands for. */
        private final SortedMap<Long, T> circle = new TreeMap<Long, T>();

        ConsistentHashSelector(HashFunction hashFunction, int numberOfReplicas,
                Collection<T> nodes) {
            this.hashFunction = hashFunction;
            this.numberOfReplicas = numberOfReplicas;
            this.source = nodes;
            this.nodeCount = nodes.size();
            for (T node : nodes) {
                add(node);
            }
        }

        /**
         * Reports whether this ring still describes {@code nodes}. This means the same list
         * instance, with the same size it had when the ring was built. In-place replacement that
         * keeps the size unchanged is not detected.
         */
        boolean isBuiltFrom(Collection<T> nodes) {
            return nodes == source && nodes.size() == nodeCount;
        }

        /** Places {@code numberOfReplicas} virtual nodes for {@code node} on the ring. */
        void add(T node) {
            for (int i = 0; i < numberOfReplicas; i++) {
                circle.put(hashFunction.hash(node.toString() + i), node);
            }
        }

        /** Removes the virtual nodes that {@link #add} placed for {@code node}. */
        @SuppressWarnings("unused")
        void remove(T node) {
            for (int i = 0; i < numberOfReplicas; i++) {
                circle.remove(hashFunction.hash(node.toString() + i));
            }
        }

        /**
         * Returns the nearest node clockwise for {@code key}. The key is hashed, and the result is
         * the real node behind the first virtual node whose hash is greater than or equal to it,
         * wrapping to the start of the ring.
         *
         * @param key value to hash
         * @return the owning node, or {@code null} if the ring is empty
         */
        T get(String key) {
            if (circle.isEmpty()) {
                return null;
            }
            final long hash = hashFunction.hash(key);
            // tailMap is the portion of the ring with hashes >= hash; if empty, wrap around.
            final SortedMap<Long, T> tail = circle.tailMap(hash);
            return circle.get(tail.isEmpty() ? circle.firstKey() : tail.firstKey());
        }

        /** Returns the number of virtual nodes on the ring. */
        @SuppressWarnings("unused")
        long getSize() {
            return circle.size();
        }
    }

    /** Ad-hoc demo: calls select on three nodes 1000 times with a null invocation and prints each result. */
    public static void main(String[] args) {
        final LoadBalance loadBalance = new ConsistentHashLoadBalance();
        final List<ServerNode> serverNodes = new ArrayList<>();
        serverNodes.add(newNode("host1", 8081));
        serverNodes.add(newNode("host2", 8082));
        serverNodes.add(newNode("host3", 8083));
        for (int i = 0; i < 1000; ++i) {
            try {
                System.out.println(loadBalance.select(serverNodes, null));
            } catch (Throwable e) {
                // Demo code only: report and keep going.
                e.printStackTrace();
            }
        }
    }

    /** Builds a demo node pointing at {@code host:port}. */
    private static ServerNodeInfo newNode(final String host, final int port) {
        final ServerNodeInfo serverNodeInfo = new ServerNodeInfo();
        final ServiceEndpoint serviceEndpoint = new ServiceEndpoint();
        serviceEndpoint.setHost(host);
        serviceEndpoint.setPort(port);
        serverNodeInfo.setServiceEndpoint(serviceEndpoint);
        return serverNodeInfo;
    }
}