/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.query.netty; import akka.dispatch.Futures; import com.google.common.util.concurrent.ThreadFactoryBuilder; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.stream.ChunkedWriteHandler; import org.apache.flink.runtime.io.network.netty.NettyBufferPool; import org.apache.flink.runtime.query.KvStateID; import org.apache.flink.runtime.query.KvStateServerAddress; import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer; import org.apache.flink.util.Preconditions; import scala.concurrent.Future; import scala.concurrent.Promise; import java.nio.channels.ClosedChannelException; import java.util.ArrayDeque; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; /** * Netty-based client querying {@link KvStateServer} instances. * * <p>This client can be used by multiple threads concurrently. Operations are * executed asynchronously and return Futures to their result. * * <p>The incoming pipeline looks as follows: * <pre> * Socket.read() -> LengthFieldBasedFrameDecoder -> KvStateServerHandler * </pre> * * <p>Received binary messages are expected to contain a frame length field. Netty's * {@link LengthFieldBasedFrameDecoder} is used to fully receive the frame before * giving it to our {@link KvStateClientHandler}. * * <p>Connections are established and closed by the client. The server only * closes the connection on a fatal failure that cannot be recovered. */ public class KvStateClient { /** Netty's Bootstrap. */ private final Bootstrap bootstrap; /** Statistics tracker */ private final KvStateRequestStats stats; /** Established connections. */ private final ConcurrentHashMap<KvStateServerAddress, EstablishedConnection> establishedConnections = new ConcurrentHashMap<>(); /** Pending connections. */ private final ConcurrentHashMap<KvStateServerAddress, PendingConnection> pendingConnections = new ConcurrentHashMap<>(); /** Atomic shut down flag. */ private final AtomicBoolean shutDown = new AtomicBoolean(); /** * Creates a client with the specified number of event loop threads. * * @param numEventLoopThreads Number of event loop threads (minimum 1). */ public KvStateClient(int numEventLoopThreads, KvStateRequestStats stats) { Preconditions.checkArgument(numEventLoopThreads >= 1, "Non-positive number of event loop threads."); NettyBufferPool bufferPool = new NettyBufferPool(numEventLoopThreads); ThreadFactory threadFactory = new ThreadFactoryBuilder() .setDaemon(true) .setNameFormat("Flink KvStateClient Event Loop Thread %d") .build(); NioEventLoopGroup nioGroup = new NioEventLoopGroup(numEventLoopThreads, threadFactory); this.bootstrap = new Bootstrap() .group(nioGroup) .channel(NioSocketChannel.class) .option(ChannelOption.ALLOCATOR, bufferPool) .handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel ch) throws Exception { ch.pipeline() .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) // ChunkedWriteHandler respects Channel writability .addLast(new ChunkedWriteHandler()); } }); this.stats = Preconditions.checkNotNull(stats, "Statistics tracker"); } /** * Returns a future holding the serialized request result. * * <p>If the server does not serve a KvState instance with the given ID, * the Future will be failed with a {@link UnknownKvStateID}. * * <p>If the KvState instance does not hold any data for the given key * and namespace, the Future will be failed with a {@link UnknownKeyOrNamespace}. * * <p>All other failures are forwarded to the Future. * * @param serverAddress Address of the server to query * @param kvStateId ID of the KvState instance to query * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance with * @return Future holding the serialized result */ public Future<byte[]> getKvState( KvStateServerAddress serverAddress, KvStateID kvStateId, byte[] serializedKeyAndNamespace) { if (shutDown.get()) { return Futures.failed(new IllegalStateException("Shut down")); } EstablishedConnection connection = establishedConnections.get(serverAddress); if (connection != null) { return connection.getKvState(kvStateId, serializedKeyAndNamespace); } else { PendingConnection pendingConnection = pendingConnections.get(serverAddress); if (pendingConnection != null) { // There was a race, use the existing pending connection. return pendingConnection.getKvState(kvStateId, serializedKeyAndNamespace); } else { // We try to connect to the server. PendingConnection pending = new PendingConnection(serverAddress); PendingConnection previous = pendingConnections.putIfAbsent(serverAddress, pending); if (previous == null) { // OK, we are responsible to connect. bootstrap.connect(serverAddress.getHost(), serverAddress.getPort()) .addListener(pending); return pending.getKvState(kvStateId, serializedKeyAndNamespace); } else { // There was a race, use the existing pending connection. return previous.getKvState(kvStateId, serializedKeyAndNamespace); } } } } /** * Shuts down the client and closes all connections. * * <p>After a call to this method, all returned futures will be failed. */ public void shutDown() { if (shutDown.compareAndSet(false, true)) { for (Map.Entry<KvStateServerAddress, EstablishedConnection> conn : establishedConnections.entrySet()) { if (establishedConnections.remove(conn.getKey(), conn.getValue())) { conn.getValue().close(); } } for (Map.Entry<KvStateServerAddress, PendingConnection> conn : pendingConnections.entrySet()) { if (pendingConnections.remove(conn.getKey()) != null) { conn.getValue().close(); } } if (bootstrap != null) { EventLoopGroup group = bootstrap.group(); if (group != null) { group.shutdownGracefully(0, 10, TimeUnit.SECONDS); } } } } /** * Closes the connection to the given server address if it exists. * * <p>If there is a request to the server a new connection will be established. * * @param serverAddress Target address of the connection to close */ public void closeConnection(KvStateServerAddress serverAddress) { PendingConnection pending = pendingConnections.get(serverAddress); if (pending != null) { pending.close(); } EstablishedConnection established = establishedConnections.remove(serverAddress); if (established != null) { established.close(); } } /** * A pending connection that is in the process of connecting. */ private class PendingConnection implements ChannelFutureListener { /** Lock to guard the connect call, channel hand in, etc. */ private final Object connectLock = new Object(); /** Address of the server we are connecting to. */ private final KvStateServerAddress serverAddress; /** Queue of requests while connecting. */ private final ArrayDeque<PendingRequest> queuedRequests = new ArrayDeque<>(); /** The established connection after the connect succeeds. */ private EstablishedConnection established; /** Closed flag. */ private boolean closed; /** Failure cause if something goes wrong. */ private Throwable failureCause; /** * Creates a pending connection to the given server. * * @param serverAddress Address of the server to connect to. */ private PendingConnection(KvStateServerAddress serverAddress) { this.serverAddress = serverAddress; } @Override public void operationComplete(ChannelFuture future) throws Exception { // Callback from the Bootstrap's connect call. if (future.isSuccess()) { handInChannel(future.channel()); } else { close(future.cause()); } } /** * Returns a future holding the serialized request result. * * <p>If the channel has been established, forward the call to the * established channel, otherwise queue it for when the channel is * handed in. * * @param kvStateId ID of the KvState instance to query * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance * with * @return Future holding the serialized result */ public Future<byte[]> getKvState(KvStateID kvStateId, byte[] serializedKeyAndNamespace) { synchronized (connectLock) { if (failureCause != null) { return Futures.failed(failureCause); } else if (closed) { return Futures.failed(new ClosedChannelException()); } else { if (established != null) { return established.getKvState(kvStateId, serializedKeyAndNamespace); } else { // Queue this and handle when connected PendingRequest pending = new PendingRequest(kvStateId, serializedKeyAndNamespace); queuedRequests.add(pending); return pending.promise.future(); } } } } /** * Hands in a channel after a successful connection. * * @param channel Channel to hand in */ private void handInChannel(Channel channel) { synchronized (connectLock) { if (closed || failureCause != null) { // Close the channel and we are done. Any queued requests // are removed on the close/failure call and after that no // new ones can be enqueued. channel.close(); } else { established = new EstablishedConnection(serverAddress, channel); PendingRequest pending; while ((pending = queuedRequests.poll()) != null) { Future<byte[]> resultFuture = established.getKvState( pending.kvStateId, pending.serializedKeyAndNamespace); pending.promise.completeWith(resultFuture); } // Publish the channel for the general public establishedConnections.put(serverAddress, established); pendingConnections.remove(serverAddress); // Check shut down for possible race with shut down. We // don't want any lingering connections after shut down, // which can happen if we don't check this here. if (shutDown.get()) { if (establishedConnections.remove(serverAddress, established)) { established.close(); } } } } } /** * Close the connecting channel with a ClosedChannelException. */ private void close() { close(new ClosedChannelException()); } /** * Close the connecting channel with an Exception (can be * <code>null</code>) or forward to the established channel. */ private void close(Throwable cause) { synchronized (connectLock) { if (!closed) { if (failureCause == null) { failureCause = cause; } if (established != null) { established.close(); } else { PendingRequest pending; while ((pending = queuedRequests.poll()) != null) { pending.promise.tryFailure(cause); } } closed = true; } } } /** * A pending request queued while the channel is connecting. */ private final class PendingRequest { private final KvStateID kvStateId; private final byte[] serializedKeyAndNamespace; private final Promise<byte[]> promise; private PendingRequest(KvStateID kvStateId, byte[] serializedKeyAndNamespace) { this.kvStateId = kvStateId; this.serializedKeyAndNamespace = serializedKeyAndNamespace; this.promise = Futures.promise(); } } @Override public String toString() { synchronized (connectLock) { return "PendingConnection{" + "serverAddress=" + serverAddress + ", queuedRequests=" + queuedRequests.size() + ", established=" + (established != null) + ", closed=" + closed + '}'; } } } /** * An established connection that wraps the actual channel instance and is * registered at the {@link KvStateClientHandler} for callbacks. */ private class EstablishedConnection implements KvStateClientHandlerCallback { /** Address of the server we are connected to. */ private final KvStateServerAddress serverAddress; /** The actual TCP channel. */ private final Channel channel; /** Pending requests keyed by request ID. */ private final ConcurrentHashMap<Long, PromiseAndTimestamp> pendingRequests = new ConcurrentHashMap<>(); /** Current request number used to assign unique request IDs. */ private final AtomicLong requestCount = new AtomicLong(); /** Reference to a failure that was reported by the channel. */ private final AtomicReference<Throwable> failureCause = new AtomicReference<>(); /** * Creates an established connection with the given channel. * * @param serverAddress Address of the server connected to * @param channel The actual TCP channel */ EstablishedConnection(KvStateServerAddress serverAddress, Channel channel) { this.serverAddress = Preconditions.checkNotNull(serverAddress, "KvStateServerAddress"); this.channel = Preconditions.checkNotNull(channel, "Channel"); // Add the client handler with the callback channel.pipeline().addLast("KvStateClientHandler", new KvStateClientHandler(this)); stats.reportActiveConnection(); } /** * Close the channel with a ClosedChannelException. */ void close() { close(new ClosedChannelException()); } /** * Close the channel with a cause. * * @param cause The cause to close the channel with. * @return Channel close future */ private boolean close(Throwable cause) { if (failureCause.compareAndSet(null, cause)) { channel.close(); stats.reportInactiveConnection(); for (long requestId : pendingRequests.keySet()) { PromiseAndTimestamp pending = pendingRequests.remove(requestId); if (pending != null && pending.promise.tryFailure(cause)) { stats.reportFailedRequest(); } } return true; } return false; } /** * Returns a future holding the serialized request result. * * @param kvStateId ID of the KvState instance to query * @param serializedKeyAndNamespace Serialized key and namespace to query KvState instance * with * @return Future holding the serialized result */ Future<byte[]> getKvState(KvStateID kvStateId, byte[] serializedKeyAndNamespace) { PromiseAndTimestamp requestPromiseTs = new PromiseAndTimestamp( Futures.<byte[]>promise(), System.nanoTime()); try { final long requestId = requestCount.getAndIncrement(); pendingRequests.put(requestId, requestPromiseTs); stats.reportRequest(); ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequest( channel.alloc(), requestId, kvStateId, serializedKeyAndNamespace); channel.writeAndFlush(buf).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { // Fail promise if not failed to write PromiseAndTimestamp pending = pendingRequests.remove(requestId); if (pending != null && pending.promise.tryFailure(future.cause())) { stats.reportFailedRequest(); } } } }); // Check failure for possible race. We don't want any lingering // promises after a failure, which can happen if we don't check // this here. Note that close is treated as a failure as well. Throwable failure = failureCause.get(); if (failure != null) { // Remove from pending requests to guard against concurrent // removal and to make sure that we only count it once as failed. PromiseAndTimestamp p = pendingRequests.remove(requestId); if (p != null && p.promise.tryFailure(failure)) { stats.reportFailedRequest(); } } } catch (Throwable t) { requestPromiseTs.promise.tryFailure(t); } return requestPromiseTs.promise.future(); } @Override public void onRequestResult(long requestId, byte[] serializedValue) { PromiseAndTimestamp pending = pendingRequests.remove(requestId); if (pending != null && pending.promise.trySuccess(serializedValue)) { long durationMillis = (System.nanoTime() - pending.timestamp) / 1_000_000; stats.reportSuccessfulRequest(durationMillis); } } @Override public void onRequestFailure(long requestId, Throwable cause) { PromiseAndTimestamp pending = pendingRequests.remove(requestId); if (pending != null && pending.promise.tryFailure(cause)) { stats.reportFailedRequest(); } } @Override public void onFailure(Throwable cause) { if (close(cause)) { // Remove from established channels, otherwise future // requests will be handled by this failed channel. establishedConnections.remove(serverAddress, this); } } @Override public String toString() { return "EstablishedConnection{" + "serverAddress=" + serverAddress + ", channel=" + channel + ", pendingRequests=" + pendingRequests.size() + ", requestCount=" + requestCount + ", failureCause=" + failureCause + '}'; } /** * Pair of promise and a timestamp. */ private class PromiseAndTimestamp { private final Promise<byte[]> promise; private final long timestamp; public PromiseAndTimestamp(Promise<byte[]> promise, long timestamp) { this.promise = promise; this.timestamp = timestamp; } } } }