package edu.berkeley.thebes.hat.client.clustering; import org.apache.thrift.TException; import org.apache.thrift.transport.TTransportException; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import edu.berkeley.thebes.common.clustering.RoutingHash; import edu.berkeley.thebes.common.config.Config; import edu.berkeley.thebes.common.config.ConfigParameterTypes.RoutingMode; import edu.berkeley.thebes.common.data.DataItem; import edu.berkeley.thebes.common.data.Version; import edu.berkeley.thebes.common.thrift.ServerAddress; import edu.berkeley.thebes.common.thrift.ThriftDataItem; import edu.berkeley.thebes.hat.common.thrift.ReplicaService; import edu.berkeley.thebes.hat.common.thrift.ThriftUtil; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; public class MasteredReplicaRouter extends ReplicaRouter { private final Map<Integer, List<ServerAddress>> replicaAddressesByCluster; private final Map<Integer, List<ReplicaService.Client>> syncReplicasByCluster; private final int numClusters; private final int numNeighbors; public MasteredReplicaRouter() throws TTransportException, IOException { assert(Config.getRoutingMode() == RoutingMode.MASTERED); this.replicaAddressesByCluster = Maps.newHashMap(); this.syncReplicasByCluster = Maps.newHashMap(); this.numClusters = Config.getNumClusters(); this.numNeighbors = Config.getServersInCluster().size(); for (int i = 0; i < numClusters; i ++) { List<ServerAddress> neighbors = Config.getServersInCluster(i+1); List<ReplicaService.Client> neighborClients = Lists.newArrayList(); for (ServerAddress neighbor : neighbors) { neighborClients.add(ThriftUtil.getReplicaServiceSyncClient(neighbor.getIP(), neighbor.getPort())); } replicaAddressesByCluster.put(i+1, neighbors); syncReplicasByCluster.put(i+1, neighborClients); } } private ReplicaService.Client getSyncReplicaByKey(String key) { int hash = RoutingHash.hashKey(key, numNeighbors); int clusterID = (hash % numClusters) + 1; return syncReplicasByCluster.get(clusterID).get(hash); } private ServerAddress getReplicaIPByKey(String key) { int hash = RoutingHash.hashKey(key, numNeighbors); int clusterID = RoutingHash.hashKey(key, numClusters) + 1; return replicaAddressesByCluster.get(clusterID).get(hash); } @Override public boolean put(String key, DataItem value) throws TException { try { return getSyncReplicaByKey(key).put(key, value.toThrift()); } catch (TException e) { throw new TException("Failed to write to " + getReplicaIPByKey(key), e); } } @Override public ThriftDataItem get(String key, Version requiredVersion) throws TException { try { return getSyncReplicaByKey(key).get(key, Version.toThrift(requiredVersion)); } catch (TException e) { throw new TException("Failed to read from " + getReplicaIPByKey(key), e); } } }