// ================================================================================================= // Copyright 2011 Twitter, Inc. // ------------------------------------------------------------------------------------------------- // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this work except in compliance with the License. // You may obtain a copy of the License in the LICENSE file, or at: // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ================================================================================================= package com.twitter.common.zookeeper; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.OutputStreamWriter; import java.io.Writer; import java.net.InetSocketAddress; import java.nio.charset.Charset; import java.util.Map; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Charsets; import com.google.common.base.Joiner; import com.google.common.base.Optional; import com.google.common.base.Preconditions; import com.google.common.base.Predicate; import com.google.common.base.Predicates; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedSet; import com.google.common.collect.Iterables; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import com.google.common.util.concurrent.UncheckedExecutionException; import com.google.gson.Gson; import org.apache.zookeeper.KeeperException; import org.apache.zookeeper.KeeperException.NoNodeException; import org.apache.zookeeper.Watcher; import org.apache.zookeeper.ZooDefs; import org.apache.zookeeper.data.ACL; import com.twitter.common.args.Arg; import com.twitter.common.args.CmdLine; import com.twitter.common.base.Command; import com.twitter.common.base.Function; import com.twitter.common.base.Supplier; import com.twitter.common.io.Codec; import com.twitter.common.io.CompatibilityCodec; import com.twitter.common.io.ThriftCodec; import com.twitter.common.util.BackoffHelper; import com.twitter.common.zookeeper.Group.GroupChangeListener; import com.twitter.common.zookeeper.Group.JoinException; import com.twitter.common.zookeeper.Group.Membership; import com.twitter.common.zookeeper.Group.WatchException; import com.twitter.common.zookeeper.ZooKeeperClient.ZooKeeperConnectionException; import com.twitter.thrift.Endpoint; import com.twitter.thrift.ServiceInstance; import com.twitter.thrift.Status; import static com.google.common.base.Preconditions.checkNotNull; /** * ZooKeeper-backed implementation of {@link ServerSet}. */ public class ServerSetImpl implements ServerSet { private static final Logger LOG = Logger.getLogger(ServerSetImpl.class.getName()); @CmdLine(name = "serverset_encode_json", help = "If true, use JSON for encoding server set information." + " Defaults to true (use JSON).") private static final Arg<Boolean> ENCODE_JSON = Arg.create(true); private final ZooKeeperClient zkClient; private final Group group; private final Codec<ServiceInstance> codec; private final BackoffHelper backoffHelper; /** * Creates a new ServerSet using open ZooKeeper node ACLs. * * @param zkClient the client to use for interactions with ZooKeeper * @param path the name-service path of the service to connect to */ public ServerSetImpl(ZooKeeperClient zkClient, String path) { this(zkClient, ZooDefs.Ids.OPEN_ACL_UNSAFE, path); } /** * Creates a new ServerSet for the given service {@code path}. * * @param zkClient the client to use for interactions with ZooKeeper * @param acl the ACL to use for creating the persistent group path if it does not already exist * @param path the name-service path of the service to connect to */ public ServerSetImpl(ZooKeeperClient zkClient, Iterable<ACL> acl, String path) { this(zkClient, new Group(zkClient, acl, path), createDefaultCodec()); } /** * Creates a new ServerSet using the given service {@code group}. * * @param zkClient the client to use for interactions with ZooKeeper * @param group the server group */ public ServerSetImpl(ZooKeeperClient zkClient, Group group) { this(zkClient, group, createDefaultCodec()); } /** * Creates a new ServerSet using the given service {@code group} and a custom {@code codec}. * * @param zkClient the client to use for interactions with ZooKeeper * @param group the server group * @param codec a codec to use for serializing and de-serializing the ServiceInstance data to and * from a byte array */ public ServerSetImpl(ZooKeeperClient zkClient, Group group, Codec<ServiceInstance> codec) { this.zkClient = checkNotNull(zkClient); this.group = checkNotNull(group); this.codec = checkNotNull(codec); // TODO(John Sirois): Inject the helper so that backoff strategy can be configurable. backoffHelper = new BackoffHelper(); } @VisibleForTesting ZooKeeperClient getZkClient() { return zkClient; } @Override public EndpointStatus join( InetSocketAddress endpoint, Map<String, InetSocketAddress> additionalEndpoints) throws JoinException, InterruptedException { LOG.log(Level.WARNING, "Joining a ServerSet without a shard ID is deprecated and will soon break."); return join(endpoint, additionalEndpoints, Optional.<Integer>absent()); } @Override public EndpointStatus join( InetSocketAddress endpoint, Map<String, InetSocketAddress> additionalEndpoints, int shardId) throws JoinException, InterruptedException { return join(endpoint, additionalEndpoints, Optional.of(shardId)); } private EndpointStatus join( InetSocketAddress endpoint, Map<String, InetSocketAddress> additionalEndpoints, Optional<Integer> shardId) throws JoinException, InterruptedException { checkNotNull(endpoint); checkNotNull(additionalEndpoints); final MemberStatus memberStatus = new MemberStatus(endpoint, additionalEndpoints, shardId); Supplier<byte[]> serviceInstanceSupplier = new Supplier<byte[]>() { @Override public byte[] get() { return memberStatus.serializeServiceInstance(); } }; final Membership membership = group.join(serviceInstanceSupplier); return new EndpointStatus() { @Override public void update(Status status) throws UpdateException { checkNotNull(status); LOG.warning("This method is deprecated. Please use leave() instead."); if (status == Status.DEAD) { leave(); } else { LOG.warning("Status update has been ignored"); } } @Override public void leave() throws UpdateException { memberStatus.leave(membership); } }; } @Override public EndpointStatus join( InetSocketAddress endpoint, Map<String, InetSocketAddress> additionalEndpoints, Status status) throws JoinException, InterruptedException { LOG.warning("This method is deprecated. Please do not specify a status field."); if (status != Status.ALIVE) { LOG.severe("**************************************************************************\n" + "WARNING: MUTABLE STATUS FIELDS ARE NO LONGER SUPPORTED.\n" + "JOINING WITH STATUS ALIVE EVEN THOUGH YOU SPECIFIED " + status + "\n**************************************************************************"); } return join(endpoint, additionalEndpoints); } @Override public Command watch(HostChangeMonitor<ServiceInstance> monitor) throws MonitorException { ServerSetWatcher serverSetWatcher = new ServerSetWatcher(zkClient, monitor); try { return serverSetWatcher.watch(); } catch (WatchException e) { throw new MonitorException("ZooKeeper watch failed.", e); } catch (InterruptedException e) { throw new MonitorException("Interrupted while watching ZooKeeper.", e); } } @Override public void monitor(HostChangeMonitor<ServiceInstance> monitor) throws MonitorException { LOG.warning("This method is deprecated. Please use watch instead."); watch(monitor); } private class MemberStatus { private final InetSocketAddress endpoint; private final Map<String, InetSocketAddress> additionalEndpoints; private final Optional<Integer> shardId; private MemberStatus( InetSocketAddress endpoint, Map<String, InetSocketAddress> additionalEndpoints, Optional<Integer> shardId) { this.endpoint = endpoint; this.additionalEndpoints = additionalEndpoints; this.shardId = shardId; } synchronized void leave(Membership membership) throws UpdateException { try { membership.cancel(); } catch (JoinException e) { throw new UpdateException( "Failed to auto-cancel group membership on transition to DEAD status", e); } } byte[] serializeServiceInstance() { ServiceInstance serviceInstance = new ServiceInstance( ServerSets.toEndpoint(endpoint), Maps.transformValues(additionalEndpoints, ServerSets.TO_ENDPOINT), Status.ALIVE); if (shardId.isPresent()) { serviceInstance.setShard(shardId.get()); } LOG.fine("updating endpoint data to:\n\t" + serviceInstance); try { return ServerSets.serializeServiceInstance(serviceInstance, codec); } catch (IOException e) { throw new IllegalStateException("Unexpected problem serializing thrift struct " + serviceInstance + "to a byte[]", e); } } } private static class ServiceInstanceFetchException extends RuntimeException { ServiceInstanceFetchException(String message, Throwable cause) { super(message, cause); } } private static class ServiceInstanceDeletedException extends RuntimeException { ServiceInstanceDeletedException(String path) { super(path); } } private class ServerSetWatcher { private final ZooKeeperClient zkClient; private final HostChangeMonitor<ServiceInstance> monitor; @Nullable private ImmutableSet<ServiceInstance> serverSet; ServerSetWatcher(ZooKeeperClient zkClient, HostChangeMonitor<ServiceInstance> monitor) { this.zkClient = zkClient; this.monitor = monitor; } public Command watch() throws WatchException, InterruptedException { Watcher onExpirationWatcher = zkClient.registerExpirationHandler(new Command() { @Override public void execute() { // Servers may have changed Status while we were disconnected from ZooKeeper, check and // re-register our node watches. rebuildServerSet(); } }); try { return group.watch(new GroupChangeListener() { @Override public void onGroupChange(Iterable<String> memberIds) { notifyGroupChange(memberIds); } }); } catch (WatchException e) { zkClient.unregister(onExpirationWatcher); throw e; } catch (InterruptedException e) { zkClient.unregister(onExpirationWatcher); throw e; } } private ServiceInstance getServiceInstance(final String nodePath) { try { return backoffHelper.doUntilResult(new Supplier<ServiceInstance>() { @Override public ServiceInstance get() { try { byte[] data = zkClient.get().getData(nodePath, false, null); return ServerSets.deserializeServiceInstance(data, codec); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new ServiceInstanceFetchException( "Interrupted updating service data for: " + nodePath, e); } catch (ZooKeeperConnectionException e) { LOG.log(Level.WARNING, "Temporary error trying to updating service data for: " + nodePath, e); return null; } catch (NoNodeException e) { invalidateNodePath(nodePath); throw new ServiceInstanceDeletedException(nodePath); } catch (KeeperException e) { if (zkClient.shouldRetry(e)) { LOG.log(Level.WARNING, "Temporary error trying to update service data for: " + nodePath, e); return null; } else { throw new ServiceInstanceFetchException( "Failed to update service data for: " + nodePath, e); } } catch (IOException e) { throw new ServiceInstanceFetchException( "Failed to deserialize the ServiceInstance data for: " + nodePath, e); } } }); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new ServiceInstanceFetchException( "Interrupted trying to update service data for: " + nodePath, e); } } private final LoadingCache<String, ServiceInstance> servicesByMemberId = CacheBuilder.newBuilder().build(new CacheLoader<String, ServiceInstance>() { @Override public ServiceInstance load(String memberId) { return getServiceInstance(group.getMemberPath(memberId)); } }); private void rebuildServerSet() { Set<String> memberIds = ImmutableSet.copyOf(servicesByMemberId.asMap().keySet()); servicesByMemberId.invalidateAll(); notifyGroupChange(memberIds); } private String invalidateNodePath(String deletedPath) { String memberId = group.getMemberId(deletedPath); servicesByMemberId.invalidate(memberId); return memberId; } private final Function<String, ServiceInstance> MAYBE_FETCH_NODE = new Function<String, ServiceInstance>() { @Override public ServiceInstance apply(String memberId) { // This get will trigger a fetch try { return servicesByMemberId.getUnchecked(memberId); } catch (UncheckedExecutionException e) { Throwable cause = e.getCause(); if (!(cause instanceof ServiceInstanceDeletedException)) { Throwables.propagateIfInstanceOf(cause, ServiceInstanceFetchException.class); throw new IllegalStateException( "Unexpected error fetching member data for: " + memberId, e); } return null; } } }; private synchronized void notifyGroupChange(Iterable<String> memberIds) { ImmutableSet<String> newMemberIds = ImmutableSortedSet.copyOf(memberIds); Set<String> existingMemberIds = servicesByMemberId.asMap().keySet(); // Ignore no-op state changes except for the 1st when we've seen no group yet. if ((serverSet == null) || !newMemberIds.equals(existingMemberIds)) { SetView<String> deletedMemberIds = Sets.difference(existingMemberIds, newMemberIds); // Implicit removal from servicesByMemberId. existingMemberIds.removeAll(ImmutableSet.copyOf(deletedMemberIds)); Iterable<ServiceInstance> serviceInstances = Iterables.filter( Iterables.transform(newMemberIds, MAYBE_FETCH_NODE), Predicates.notNull()); notifyServerSetChange(ImmutableSet.copyOf(serviceInstances)); } } private void notifyServerSetChange(ImmutableSet<ServiceInstance> currentServerSet) { // ZK nodes may have changed if there was a session expiry for a server in the server set, but // if the server's status has not changed, we can skip any onChange updates. if (!currentServerSet.equals(serverSet)) { if (currentServerSet.isEmpty()) { LOG.warning("server set empty for path " + group.getPath()); } else { if (LOG.isLoggable(Level.INFO)) { if (serverSet == null) { LOG.info("received initial membership " + currentServerSet); } else { logChange(Level.INFO, currentServerSet); } } } serverSet = currentServerSet; monitor.onChange(serverSet); } } private void logChange(Level level, ImmutableSet<ServiceInstance> newServerSet) { StringBuilder message = new StringBuilder("server set " + group.getPath() + " change: "); if (serverSet.size() != newServerSet.size()) { message.append("from ").append(serverSet.size()) .append(" members to ").append(newServerSet.size()); } Joiner joiner = Joiner.on("\n\t\t"); SetView<ServiceInstance> left = Sets.difference(serverSet, newServerSet); if (!left.isEmpty()) { message.append("\n\tleft:\n\t\t").append(joiner.join(left)); } SetView<ServiceInstance> joined = Sets.difference(newServerSet, serverSet); if (!joined.isEmpty()) { message.append("\n\tjoined:\n\t\t").append(joiner.join(joined)); } LOG.log(level, message.toString()); } } private static class EndpointSchema { final String host; final Integer port; EndpointSchema(Endpoint endpoint) { Preconditions.checkNotNull(endpoint); this.host = endpoint.getHost(); this.port = endpoint.getPort(); } String getHost() { return host; } Integer getPort() { return port; } } private static class ServiceInstanceSchema { final EndpointSchema serviceEndpoint; final Map<String, EndpointSchema> additionalEndpoints; final Status status; final Integer shard; ServiceInstanceSchema(ServiceInstance instance) { this.serviceEndpoint = new EndpointSchema(instance.getServiceEndpoint()); if (instance.getAdditionalEndpoints() != null) { this.additionalEndpoints = Maps.transformValues( instance.getAdditionalEndpoints(), new Function<Endpoint, EndpointSchema>() { @Override public EndpointSchema apply(Endpoint endpoint) { return new EndpointSchema(endpoint); } } ); } else { this.additionalEndpoints = Maps.newHashMap(); } this.status = instance.getStatus(); this.shard = instance.isSetShard() ? instance.getShard() : null; } EndpointSchema getServiceEndpoint() { return serviceEndpoint; } Map<String, EndpointSchema> getAdditionalEndpoints() { return additionalEndpoints; } Status getStatus() { return status; } Integer getShard() { return shard; } } /** * An adapted JSON codec that makes use of {@link ServiceInstanceSchema} to circumvent the * __isset_bit_vector internal thrift struct field that tracks primitive types. */ private static class AdaptedJsonCodec implements Codec<ServiceInstance> { private static final Charset ENCODING = Charsets.UTF_8; private static final Class<ServiceInstanceSchema> CLASS = ServiceInstanceSchema.class; private final Gson gson = new Gson(); @Override public void serialize(ServiceInstance instance, OutputStream sink) throws IOException { Writer w = new OutputStreamWriter(sink, ENCODING); gson.toJson(new ServiceInstanceSchema(instance), CLASS, w); w.flush(); } @Override public ServiceInstance deserialize(InputStream source) throws IOException { ServiceInstanceSchema output = gson.fromJson(new InputStreamReader(source, ENCODING), CLASS); Endpoint primary = new Endpoint( output.getServiceEndpoint().getHost(), output.getServiceEndpoint().getPort()); Map<String, Endpoint> additional = Maps.transformValues( output.getAdditionalEndpoints(), new Function<EndpointSchema, Endpoint>() { @Override public Endpoint apply(EndpointSchema endpoint) { return new Endpoint(endpoint.getHost(), endpoint.getPort()); } } ); ServiceInstance instance = new ServiceInstance(primary, ImmutableMap.copyOf(additional), output.getStatus()); if (output.getShard() != null) { instance.setShard(output.getShard()); } return instance; } } private static Codec<ServiceInstance> createCodec(final boolean useJsonEncoding) { final Codec<ServiceInstance> json = new AdaptedJsonCodec(); final Codec<ServiceInstance> thrift = ThriftCodec.create(ServiceInstance.class, ThriftCodec.BINARY_PROTOCOL); final Predicate<byte[]> recognizer = new Predicate<byte[]>() { public boolean apply(byte[] input) { return (input.length > 1 && input[0] == '{' && input[1] == '\"') == useJsonEncoding; } }; if (useJsonEncoding) { return CompatibilityCodec.create(json, thrift, 2, recognizer); } return CompatibilityCodec.create(thrift, json, 2, recognizer); } /** * Creates a codec for {@link ServiceInstance} objects that uses Thrift binary encoding, and can * decode both Thrift and JSON encodings. * * @return a new codec instance. */ public static Codec<ServiceInstance> createThriftCodec() { return createCodec(false); } /** * Creates a codec for {@link ServiceInstance} objects that uses JSON encoding, and can decode * both Thrift and JSON encodings. * * @return a new codec instance. */ public static Codec<ServiceInstance> createJsonCodec() { return createCodec(true); } /** * Returns a codec for {@link ServiceInstance} objects that uses either the Thrift or the JSON * encoding, depending on whether the command line argument <tt>serverset_json_encofing</tt> is * set to <tt>true</tt>, and can decode both Thrift and JSON encodings. * * @return a new codec instance. */ public static Codec<ServiceInstance> createDefaultCodec() { return createCodec(ENCODE_JSON.get()); } }