/**
 * Copyright 2012 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package net.sf.katta.master;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import net.sf.katta.util.CircularList;

import org.apache.log4j.Logger;

/**
 * Deploy policy that does not replicate a shard across nodes living on the same
 * host (host-aware extension of DefaultDistributionPolicy). Node names are
 * parsed as {@code hostname:port}.
 * <p/>
 * Algorithm: for a given shard, hosts NOT yet serving that shard are preferred;
 * on each host the node with the fewest shards is picked first. Only when those
 * are exhausted are hosts already serving the shard used (their nodes that
 * already hold the shard are never chosen twice).
 * <p/>
 * Supported features:<br>
 * - initial shard distribution<br>
 * - shard distribution when under-replicated<br>
 * - shard removal when over-replicated<br>
 * <p/>
 * Missing feature:<br>
 * - shard/node rebalancing<br>
 * <p/>
 * TODO jz: node load rebalancing
 */
public class HostAwareDistributionPolicy implements IDeployPolicy {

  private final static Logger LOG = Logger.getLogger(HostAwareDistributionPolicy.class);

  /**
   * Computes a new node-to-shards assignment for the given cluster state.
   *
   * @param currentShard2NodesMap current shard -> serving nodes mapping (one entry per shard)
   * @param currentNode2ShardsMap current node -> served shards mapping; mutated in place and returned
   * @param aliveNodes            nodes available for deployment, names of form hostname:port
   * @param replicationLevel      desired number of replicas per shard
   * @return the (mutated) node -> shards map representing the distribution plan
   * @throws IllegalArgumentException if no alive nodes are given
   */
  public Map<String, List<String>> createDistributionPlan(final Map<String, List<String>> currentShard2NodesMap,
          final Map<String, List<String>> currentNode2ShardsMap, List<String> aliveNodes, final int replicationLevel) {
    if (aliveNodes.size() == 0) {
      throw new IllegalArgumentException("no alive nodes to distribute to");
    }
    Set<String> shards = currentShard2NodesMap.keySet();
    for (String shard : shards) {
      Set<String> assignedNodes = new HashSet<String>(replicationLevel);
      int neededDeployments = replicationLevel - countValues(currentShard2NodesMap, shard);
      assignedNodes.addAll(currentShard2NodesMap.get(shard));

      // assign new nodes, host-aware: least-loaded nodes of shard-free hosts come first
      List<String> sortedNodes = sortNodesByCapacityHostAware(aliveNodes, shard, currentNode2ShardsMap);
      CircularList<String> roundRobinNodes = new CircularList<String>(sortedNodes);
      neededDeployments = chooseNewNodes(currentNode2ShardsMap, roundRobinNodes, shard, assignedNodes,
              neededDeployments);
      if (neededDeployments > 0) {
        LOG.warn("cannot replicate shard '" + shard + "' " + replicationLevel + " times, cause only "
                + roundRobinNodes.size() + " nodes connected");
      } else if (neededDeployments < 0) {
        LOG.info("found shard '" + shard + "' over-replicated");
        // TODO jz: maybe we should add a configurable threshold tha e.g. 10%
        // over replication is ok ?
        removeOverreplicatedShards(currentShard2NodesMap, currentNode2ShardsMap, shard, neededDeployments);
      }
    }
    return currentNode2ShardsMap;
  }

  /** Sorts the given nodes in place, fewest served shards first. */
  private void sortByFreeCapacity(List<String> nodes, final Map<String, List<String>> node2ShardsMap) {
    Collections.sort(nodes, new Comparator<String>() {
      @Override
      public int compare(String node1, String node2) {
        int size1 = node2ShardsMap.get(node1).size();
        int size2 = node2ShardsMap.get(node2).size();
        return (size1 < size2 ? -1 : (size1 == size2 ? 0 : 1));
      }
    });
  }

  /**
   * Orders all alive nodes so that nodes on hosts NOT serving the shard come first,
   * interleaving hosts round-robin (least-loaded node of each host first) so that
   * consecutive picks land on distinct hosts whenever possible.
   */
  private List<String> sortNodesByCapacityHostAware(List<String> aliveNodes, String shard,
          final Map<String, List<String>> node2ShardsMap) {
    List<String> sortedNodes = new ArrayList<String>();
    List<String> aliveHosts = parseHosts(aliveNodes);
    Map<String, List<String>> host2ShardsMap = createHostToShardsMap(aliveHosts, node2ShardsMap);
    Map<String, List<String>> host2NodesMap = createHostToSortedNodesMap(aliveNodes, node2ShardsMap);
    sortHostsByCapacityForShard(aliveHosts, shard, host2ShardsMap);
    // round-robin across hosts: take the next (least-loaded) node of each host in turn
    boolean nodesLeft = true;
    while (nodesLeft) {
      nodesLeft = false;
      for (String host : aliveHosts) {
        List<String> nodes = host2NodesMap.get(host);
        if (nodes.size() > 0) {
          nodesLeft = true;
          sortedNodes.add(nodes.remove(0));
        }
      }
    }
    return sortedNodes;
  }

  /** Groups alive nodes by host; each host's node list is sorted by free capacity. */
  private Map<String, List<String>> createHostToSortedNodesMap(List<String> aliveNodes,
          Map<String, List<String>> node2ShardsMap) {
    Map<String, List<String>> host2NodesMap = new HashMap<String, List<String>>();
    for (String node : aliveNodes) {
      String host = parseHost(node);
      if (!host2NodesMap.containsKey(host)) {
        host2NodesMap.put(host, new ArrayList<String>());
      }
      host2NodesMap.get(host).add(node);
    }
    for (String host : host2NodesMap.keySet()) {
      sortByFreeCapacity(host2NodesMap.get(host), node2ShardsMap);
    }
    return host2NodesMap;
  }

  /**
   * Sorts hosts in place: hosts not serving the shard first, then by total shard
   * count ascending (lexicographic order on (hasShard, shardCount) — a valid
   * total order, so the comparator contract holds).
   */
  private void sortHostsByCapacityForShard(List<String> aliveHosts, final String shard,
          final Map<String, List<String>> host2ShardsMap) {
    Collections.sort(aliveHosts, new Comparator<String>() {
      @Override
      public int compare(String host1, String host2) {
        if (host2ShardsMap.get(host1).contains(shard) && !host2ShardsMap.get(host2).contains(shard))
          return 1;
        if (!host2ShardsMap.get(host1).contains(shard) && host2ShardsMap.get(host2).contains(shard))
          return -1;
        int size1 = host2ShardsMap.get(host1).size();
        int size2 = host2ShardsMap.get(host2).size();
        return (size1 < size2 ? -1 : (size1 == size2 ? 0 : 1));
      }
    });
  }

  /**
   * creates map of hosts to list of shards. Hosts without shards contain an entry with an empty list.
   * Nodes in node2ShardsMap whose host is not alive are ignored.
   */
  private Map<String, List<String>> createHostToShardsMap(List<String> aliveHosts,
          Map<String, List<String>> node2ShardsMap) {
    Map<String, List<String>> host2ShardsMap = new HashMap<String, List<String>>();
    for (String host : aliveHosts) {
      host2ShardsMap.put(host, new ArrayList<String>());
    }
    for (String node : node2ShardsMap.keySet()) {
      String host = parseHost(node);
      if (host2ShardsMap.containsKey(host)) {
        host2ShardsMap.get(host).addAll(node2ShardsMap.get(node));
      }
    }
    return host2ShardsMap;
  }

  /** Returns the distinct hosts of the given nodes (order unspecified). */
  private List<String> parseHosts(List<String> aliveNodes) {
    Set<String> hosts = new HashSet<String>();
    for (String node : aliveNodes) {
      String host = parseHost(node);
      hosts.add(host);
    }
    return new ArrayList<String>(hosts);
  }

  /** Extracts the host part of a "hostname:port" node name. */
  private String parseHost(String node) {
    return node.substring(0, node.indexOf(":"));
  }

  /**
   * Assigns the shard to up to neededDeployments not-yet-assigned nodes, walking
   * the pre-sorted circular list at most one full round.
   *
   * @return the number of deployments still outstanding (may be negative if over-replicated)
   */
  private int chooseNewNodes(final Map<String, List<String>> currentNode2ShardsMap,
          CircularList<String> roundRobinNodes, String shard, Set<String> assignedNodes, int neededDeployments) {
    String tailNode = roundRobinNodes.getTail();
    String currentNode = null;
    while (neededDeployments > 0 && !tailNode.equals(currentNode)) {
      currentNode = roundRobinNodes.getNext();
      if (!assignedNodes.contains(currentNode)) {
        if (LOG.isDebugEnabled()) {
          LOG.debug("assign " + shard + " to " + currentNode);
        }
        currentNode2ShardsMap.get(currentNode).add(shard);
        assignedNodes.add(currentNode);
        neededDeployments--;
      }
    }
    return neededDeployments;
  }

  /**
   * Removes the shard from the most-loaded serving nodes until the replication
   * level is reached. Tracks nodes already handled so the same node is never
   * selected twice (currentShard2NodesMap itself is left untouched, so without
   * this tracking the previous pick — now a no-op remove — could be re-selected
   * while neededDeployments still advanced, leaving the shard over-replicated).
   */
  private void removeOverreplicatedShards(final Map<String, List<String>> currentShard2NodesMap,
          final Map<String, List<String>> currentNode2ShardsMap, String shard, int neededDeployments) {
    Set<String> alreadyRemoved = new HashSet<String>();
    while (neededDeployments < 0) {
      int maxShardServingCount = 0;
      String maxShardServingNode = null;
      List<String> nodeNames = currentShard2NodesMap.get(shard);
      for (String node : nodeNames) {
        if (alreadyRemoved.contains(node)) {
          continue;
        }
        int shardCount = countValues(currentNode2ShardsMap, node);
        if (shardCount > maxShardServingCount) {
          maxShardServingCount = shardCount;
          maxShardServingNode = node;
        }
      }
      if (maxShardServingNode == null) {
        // no removable candidate left (e.g. serving node unknown to node2ShardsMap) — avoid NPE
        break;
      }
      currentNode2ShardsMap.get(maxShardServingNode).remove(shard);
      alreadyRemoved.add(maxShardServingNode);
      neededDeployments++;
    }
  }

  /** Returns the size of the list mapped to key, or 0 if the key is absent. */
  private int countValues(Map<String, List<String>> multiMap, String key) {
    List<String> list = multiMap.get(key);
    if (list == null) {
      return 0;
    }
    return list.size();
  }
}