package edu.berkeley.thebes.hat.client.clustering; import org.apache.thrift.TException; import org.apache.thrift.transport.TTransportException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import edu.berkeley.thebes.common.clustering.RoutingHash; import edu.berkeley.thebes.common.config.Config; import edu.berkeley.thebes.common.config.ConfigParameterTypes.RoutingMode; import edu.berkeley.thebes.common.data.DataItem; import edu.berkeley.thebes.common.data.Version; import edu.berkeley.thebes.common.thrift.ServerAddress; import edu.berkeley.thebes.common.thrift.ThriftDataItem; import edu.berkeley.thebes.hat.common.thrift.ReplicaService; import edu.berkeley.thebes.hat.common.thrift.ThriftUtil; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.SortedSet; import java.util.TreeSet; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicLong; public class NearestReplicaRouter extends ReplicaRouter { private static Logger logger = LoggerFactory.getLogger(NearestReplicaRouter.class); private static final double ALPHA = .95; private static final double TIME_BETWEEN_CHECKS = 10000; private static final double WARNING_THRESHOLD = 10; private List<ReplicaService.Client> syncReplicas; private static ConcurrentMap<ServerAddress, Double> averageLatencyByServer = Maps.newConcurrentMap(); private static AtomicLong timeSinceLastCheck; public NearestReplicaRouter() throws TTransportException, IOException { assert(Config.getRoutingMode() == RoutingMode.NEAREST); List<ServerAddress> serverIPs = Config.getServersInCluster(); syncReplicas = new ArrayList<ReplicaService.Client>(serverIPs.size()); for (ServerAddress server : serverIPs) { syncReplicas.add(ThriftUtil.getReplicaServiceSyncClient(server.getIP(), server.getPort())); averageLatencyByServer.putIfAbsent(server, 0d); logger.trace("Connected to " + server); } timeSinceLastCheck = new AtomicLong(); } private ReplicaService.Client getSyncReplicaByKey(String key) { return syncReplicas.get(RoutingHash.hashKey(key, syncReplicas.size())); } private ServerAddress getReplicaIPByKey(String key) { return Config.getServersInCluster().get(RoutingHash.hashKey(key, syncReplicas.size())); } @Override public boolean put(String key, DataItem value) throws TException { try { ServerAddress serverAddress = getReplicaIPByKey(key); ReplicaService.Client serverClient = getSyncReplicaByKey(key); long startTime = System.currentTimeMillis(); boolean ret = serverClient.put(key, value.toThrift()); long duration = System.currentTimeMillis() - startTime; averageLatencyByServer.put(serverAddress, averageLatencyByServer.get(serverAddress)*ALPHA + duration*(1-ALPHA)); checkLatencies(); return ret; } catch (TException e) { throw new TException("Failed to write to " + getReplicaIPByKey(key), e); } } private void checkLatencies() { if (timeSinceLastCheck.get() == 0) { timeSinceLastCheck.set(System.currentTimeMillis()); } if (System.currentTimeMillis() - timeSinceLastCheck.get() < TIME_BETWEEN_CHECKS) { return; } else { timeSinceLastCheck.set(System.currentTimeMillis()); } double minLatency = -1; for (double latency : averageLatencyByServer.values()) { if (minLatency == -1 || latency < minLatency) { minLatency = latency; } } for (ServerAddress server : averageLatencyByServer.keySet()) { double latency = averageLatencyByServer.get(server); if (latency > WARNING_THRESHOLD * minLatency) { logger.warn("Server " + server + " has high avg put latency: " + latency + " (min avg latency=" + minLatency + ")"); } else if (logger.isDebugEnabled()) { logger.debug("Server " + server + " has avg put latency: " + latency); } } } @Override public ThriftDataItem get(String key, Version requiredVersion) throws TException { try { return getSyncReplicaByKey(key).get(key, Version.toThrift(requiredVersion)); } catch (TException e) { throw new TException("Failed to read from " + getReplicaIPByKey(key), e); } } }