package org.corfudb.runtime.clients;
import com.codahale.metrics.Counter;
import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.ssl.SslContext;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.corfudb.protocols.wireprotocol.CorfuMsg;
import org.corfudb.protocols.wireprotocol.CorfuMsgType;
import org.corfudb.protocols.wireprotocol.NettyCorfuMessageDecoder;
import org.corfudb.protocols.wireprotocol.NettyCorfuMessageEncoder;
import org.corfudb.runtime.CorfuRuntime;
import org.corfudb.runtime.exceptions.NetworkException;
import org.corfudb.runtime.exceptions.WrongEpochException;
import org.corfudb.security.sasl.plaintext.PlainTextSaslNettyClient;
import org.corfudb.security.sasl.SaslUtils;
import org.corfudb.security.tls.TlsUtils;
import org.corfudb.util.CFUtils;
import org.corfudb.util.MetricsUtils;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* A client router which multiplexes operations over the Netty transport.
* <p>
* Created by mwei on 12/8/15.
*/
@Slf4j
@ChannelHandler.Sharable
public class NettyClientRouter extends SimpleChannelInboundHandler<CorfuMsg>
implements IClientRouter {
/**
* Metrics: meter (counter), histogram
*/
private Gauge<Integer> gaugeConnected;
private Timer timerConnect;
private Timer timerSyncOp;
private Counter counterConnectFailed;
private Counter counterSendDisconnected;
private Counter counterSendTimeout;
private Counter counterAsyncOpSent;
/**
* A random instance
*/
public static final Random random = new Random();
/**
* The epoch this router is in.
*/
@Getter
@Setter
public long epoch;
/**
* The id of this client.
*/
@Getter
@Setter
public UUID clientID;
/**
* New connection timeout (milliseconds)
*/
@Getter
@Setter
public long timeoutConnect;
/**
* Sync call response timeout (milliseconds)
*/
@Getter
@Setter
public long timeoutResponse;
/**
* Retry interval after timeout (milliseconds)
*/
@Getter
@Setter
public long timeoutRetry;
/**
* The current request ID.
*/
@Getter
public AtomicLong requestID;
/**
* The handlers registered to this router.
*/
public Map<CorfuMsgType, IClient> handlerMap;
/**
* The clients registered to this router.
*/
public List<IClient> clientList;
/**
* The outstanding requests on this router.
*/
public Map<Long, CompletableFuture> outstandingRequests;
/**
* The currently registered channel context.
*/
public ChannelHandlerContext context;
/**
* The currently registered channel.
*/
public Channel channel = null;
/**
* The worker group for this router.
*/
public EventLoopGroup workerGroup;
/**
* The event executor group for this router.
*/
public EventExecutorGroup ee;
/**
* Whether or not this router is shutdown.
*/
public volatile boolean shutdown;
/**
* The host that this router is routing requests for.
*/
@Getter
String host;
/**
* The port that this router is routing requests for.
*/
@Getter
Integer port;
/**
* Are we connected?
*/
@Getter
volatile Boolean connected;
private Bootstrap b;
private Boolean tlsEnabled = false;
private SslContext sslContext;
private Boolean saslPlainTextEnabled = false;
private String saslPlainTextUsernameFile;
private String saslPlainTextPasswordFile;
public NettyClientRouter(String endpoint) {
this(endpoint.split(":")[0], Integer.parseInt(endpoint.split(":")[1]),
false, null, null, null, null, false, null, null);
}
public NettyClientRouter(String host, Integer port) {
this(host, port, false, null, null, null, null, false, null, null);
}
public NettyClientRouter(String host, Integer port, Boolean tls,
String keyStore, String ksPasswordFile, String trustStore,
String tsPasswordFile, Boolean saslPlainText, String usernameFile,
String passwordFile) {
this.host = host;
this.port = port;
clientID = UUID.randomUUID();
connected = false;
timeoutConnect = 500;
timeoutResponse = 5000;
timeoutRetry = 1000;
handlerMap = new ConcurrentHashMap<>();
clientList = new ArrayList<>();
requestID = new AtomicLong();
outstandingRequests = new ConcurrentHashMap<>();
shutdown = true;
MetricRegistry metrics = CorfuRuntime.getMetrics();
String pfx = CorfuRuntime.getMpCR() + host + ":" + port.toString() + ".";
synchronized (metrics) {
if (!metrics.getNames().contains(pfx + "connected")) {
gaugeConnected = metrics.register(pfx + "connected", () -> connected ? 1 : 0);
}
}
timerConnect = metrics.timer(pfx + "connect");
timerSyncOp = metrics.timer(pfx + "sync-op");
counterConnectFailed = metrics.counter(pfx + "connect-failed");
counterSendDisconnected = metrics.counter(pfx + "send-disconnected");
counterSendTimeout = metrics.counter(pfx + "send-timeout");
counterAsyncOpSent = metrics.counter(pfx + "async-op-sent");
if (tls) {
sslContext =
TlsUtils.enableTls(TlsUtils.SslContextType.CLIENT_CONTEXT,
keyStore, e -> {
throw new RuntimeException("Could not read the key store " +
"password file: " + e.getClass().getSimpleName(), e);
},
ksPasswordFile, e -> {
throw new RuntimeException("Could not load keys from the key " +
"store: " + e.getClass().getSimpleName(), e);
},
trustStore, e -> {
throw new RuntimeException("Could not read the trust store " +
"password file: " + e.getClass().getSimpleName(), e);
},
tsPasswordFile, e -> {
throw new RuntimeException("Could not load keys from the trust " +
"store: " + e.getClass().getSimpleName(), e);
});
this.tlsEnabled = true;
}
if (saslPlainText) {
saslPlainTextUsernameFile = usernameFile;
saslPlainTextPasswordFile = passwordFile;
saslPlainTextEnabled = true;
}
addClient(new BaseClient());
start();
}
/**
* Add a new client to the router.
*
* @param client The client to add to the router.
* @return This NettyClientRouter, to support chaining and the builder pattern.
*/
public IClientRouter addClient(IClient client) {
// Set the client's router to this instance.
client.setRouter(this);
// Iterate through all types of CorfuMsgType, registering the handler
client.getHandledTypes().stream()
.forEach(x -> {
handlerMap.put(x, client);
log.trace("Registered {} to handle messages of type {}", client, x);
});
// Register this type
clientList.add(client);
return this;
}
/**
* Gets a client that matches a particular type.
*
* @param clientType The class of the client to match.
* @param <T> The type of the client to match.
* @return The first client that matches that type.
* @throws NoSuchElementException If there are no clients matching that type.
*/
@SuppressWarnings("unchecked")
public <T extends IClient> T getClient(Class<T> clientType) {
return (T) clientList.stream()
.filter(clientType::isInstance)
.findFirst().get();
}
public void start() {
start(-1);
}
public void start(long c) {
shutdown = false;
if (workerGroup == null
|| workerGroup.isShutdown()
|| !channel.isOpen()
) {
workerGroup = new NioEventLoopGroup(Runtime.getRuntime().availableProcessors() * 2, new ThreadFactory() {
final AtomicInteger threadNum = new AtomicInteger(0);
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r);
t.setName("worker-" + threadNum.getAndIncrement());
t.setDaemon(true);
return t;
}
});
ee = new DefaultEventExecutorGroup(Runtime.getRuntime().availableProcessors() * 2, new ThreadFactory() {
final AtomicInteger threadNum = new AtomicInteger(0);
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r);
t.setName(this.getClass().getName() + "event-" + threadNum.getAndIncrement());
t.setDaemon(true);
return t;
}
});
Bootstrap b = new Bootstrap();
b.group(workerGroup);
b.channel(NioSocketChannel.class);
b.option(ChannelOption.SO_KEEPALIVE, true);
b.option(ChannelOption.SO_REUSEADDR, true);
b.option(ChannelOption.TCP_NODELAY, true);
NettyClientRouter router = this;
b.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
if (tlsEnabled) {
ch.pipeline().addLast("ssl", sslContext.newHandler(ch.alloc()));
}
ch.pipeline().addLast(new LengthFieldPrepender(4));
ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4));
if (saslPlainTextEnabled) {
PlainTextSaslNettyClient saslNettyClient =
SaslUtils.enableSaslPlainText(saslPlainTextUsernameFile, saslPlainTextPasswordFile);
ch.pipeline().addLast("sasl/plain-text", saslNettyClient);
}
ch.pipeline().addLast(ee, new NettyCorfuMessageDecoder());
ch.pipeline().addLast(ee, new NettyCorfuMessageEncoder());
ch.pipeline().addLast(ee, router);
}
});
try {
connectChannel(b, c);
} catch (Exception e) {
try {
// shutdown EventLoopGroup
workerGroup.shutdownGracefully().sync();
} catch (InterruptedException ie) {
}
throw new NetworkException(e.getClass().getSimpleName() +
" connecting to endpoint failed", host + ":" + port, e);
}
}
}
synchronized void connectChannel(Bootstrap b, long c) {
boolean isEnabled = MetricsUtils.isMetricsCollectionEnabled();
try (Timer.Context context = MetricsUtils.getConditionalContext(isEnabled, timerConnect)) {
ChannelFuture cf = b.connect(host, port);
cf.syncUninterruptibly();
if (!cf.awaitUninterruptibly(timeoutConnect)) {
cf.channel().close(); // close port
MetricsUtils.incConditionalCounter(isEnabled, counterConnectFailed, 1);
throw new NetworkException(c + " Timeout connecting to endpoint", host + ":" + port);
}
channel = cf.channel();
}
channel.closeFuture().addListener((r) -> {
connected = false;
outstandingRequests.forEach((ReqID, reqCF) -> {
MetricsUtils.incConditionalCounter(isEnabled, counterSendDisconnected, 1);
reqCF.completeExceptionally(new NetworkException("Disconnected", host + ":" + port));
outstandingRequests.remove(ReqID);
});
if (!shutdown) {
log.trace("Disconnected, reconnecting...");
while (!shutdown) {
try {
connectChannel(b, c);
return;
} catch (Exception ex) {
MetricsUtils.incConditionalCounter(isEnabled, counterConnectFailed, 1);
log.trace("Exception while reconnecting, retry in {} ms", timeoutRetry);
Thread.sleep(timeoutRetry);
}
}
}
});
connected = true;
}
/**
* Stops routing requests.
*/
@Override
public void stop() {
stop(false);
}
@Override
public void stop(boolean shutdown) {
// A very hasty check of Netty state-of-the-art is that shutting down
// the worker threads is tricksy or impossible.
this.shutdown = shutdown;
connected = false;
if (shutdown) {
try {
ChannelFuture cf = channel.close();
cf.syncUninterruptibly();
cf.awaitUninterruptibly(1000);
} catch (Exception e) {
log.error("Error in closing channel");
}
try {
ee.shutdownGracefully().sync();
workerGroup.shutdownGracefully().sync();
} catch (InterruptedException e) {
log.error("Interrupted exception in shutting event pool : {}", e);
}
} else {
ChannelFuture cf = channel.disconnect();
cf.syncUninterruptibly();
boolean b1 = cf.awaitUninterruptibly(1000);
}
}
/**
* Send a message and get a completable future to be fulfilled by the reply.
*
* @param ctx The channel handler context to send the message under.
* @param message The message to send.
* @param <T> The type of completable to return.
* @return A completable future which will be fulfilled by the reply,
* or a timeout in the case there is no response.
*/
public <T> CompletableFuture<T> sendMessageAndGetCompletable(ChannelHandlerContext ctx, CorfuMsg message) {
boolean isEnabled = MetricsUtils.isMetricsCollectionEnabled();
if (!connected) {
log.trace("Disconnected endpoint " + host + ":" + port);
MetricsUtils.incConditionalCounter(isEnabled, counterSendDisconnected, 1);
throw new NetworkException("Disconnected endpoint", host + ":" + port);
} else {
Timer.Context context = MetricsUtils.getConditionalContext(isEnabled, timerSyncOp);
// Get the next request ID.
final long thisRequest = requestID.getAndIncrement();
// Set the message fields.
message.setClientID(clientID);
message.setRequestID(thisRequest);
message.setEpoch(epoch);
// Generate a future and put it in the completion table.
final CompletableFuture<T> cf = new CompletableFuture<>();
outstandingRequests.put(thisRequest, cf);
// Write the message out to the channel.
if (ctx == null) {
channel.writeAndFlush(message);
} else {
ctx.writeAndFlush(message);
}
log.trace("Sent message: {}", message);
final CompletableFuture<T> cfElapsed = cf.thenApply(x -> {
MetricsUtils.stopConditionalContext(context);
return x;
});
// Generate a timeout future, which will complete exceptionally if the main future is not completed.
final CompletableFuture<T> cfTimeout = CFUtils.within(cfElapsed, Duration.ofMillis(timeoutResponse));
cfTimeout.exceptionally(e -> {
MetricsUtils.incConditionalCounter(isEnabled, counterSendTimeout, 1);
outstandingRequests.remove(thisRequest);
log.debug("Remove request {} due to timeout!", thisRequest);
return null;
});
return cfTimeout;
}
}
/**
* Send a one way message, without adding a completable future.
*
* @param ctx The context to send the message under.
* @param message The message to send.
*/
public void sendMessage(ChannelHandlerContext ctx, CorfuMsg message) {
ChannelHandlerContext outContext = context;
if (ctx == null) {
if (context == null) {
// if the router's context is not set, return a failure
log.warn("Attempting to send on a channel that is not ready.");
return;
}
outContext = context;
}
// Get the next request ID.
final long thisRequest = requestID.getAndIncrement();
// Set the base fields for this message.
message.setClientID(clientID);
message.setRequestID(thisRequest);
message.setEpoch(epoch);
// Write this message out on the channel.
outContext.writeAndFlush(message);
MetricsUtils.incConditionalCounter(MetricsUtils.isMetricsCollectionEnabled(), counterAsyncOpSent, 1);
log.trace("Sent one-way message: {}", message);
}
/**
* Send a netty message through this router, setting the fields in the outgoing message.
*
* @param ctx Channel handler context to use.
* @param inMsg Incoming message to respond to.
* @param outMsg Outgoing message.
*/
public void sendResponseToServer(ChannelHandlerContext ctx, CorfuMsg inMsg, CorfuMsg outMsg) {
outMsg.copyBaseFields(inMsg);
outMsg.setEpoch(epoch);
ctx.writeAndFlush(outMsg);
log.trace("Sent response: {}", outMsg);
}
/**
* Complete a given outstanding request with a completion value.
*
* @param requestID The request to complete.
* @param completion The value to complete the request with
* @param <T> The type of the completion.
*/
@SuppressWarnings("unchecked")
public <T> void completeRequest(long requestID, T completion) {
CompletableFuture<T> cf;
if ((cf = (CompletableFuture<T>) outstandingRequests.get(requestID)) != null) {
cf.complete(completion);
outstandingRequests.remove(requestID);
} else {
log.warn("Attempted to complete request {}, but request not outstanding!", requestID);
}
}
/**
* Exceptionally complete a request with a given cause.
*
* @param requestID The request to complete.
* @param cause The cause to give for the exceptional completion.
*/
public void completeExceptionally(long requestID, Throwable cause) {
CompletableFuture cf;
if ((cf = outstandingRequests.get(requestID)) != null) {
cf.completeExceptionally(cause);
outstandingRequests.remove(requestID);
} else {
log.warn("Attempted to exceptionally complete request {}, but request not outstanding!", requestID);
}
}
/**
* Validate the epoch of a CorfuMsg, and send a WRONG_EPOCH response if
* the server is in the wrong epoch. Ignored if the message type is reset (which
* is valid in any epoch).
*
* @param msg The incoming message to validate.
* @param ctx The context of the channel handler.
* @return True, if the epoch is correct, but false otherwise.
*/
private boolean validateEpochAndClientID(CorfuMsg msg, ChannelHandlerContext ctx) {
// Check if the message is intended for us. If not, drop the message.
if (!msg.getClientID().equals(clientID)) {
log.warn("Incoming message intended for client {}, our id is {}, dropping!", msg.getClientID(), clientID);
return false;
}
// Check if the message is in the right epoch.
if (!msg.getMsgType().ignoreEpoch && msg.getEpoch() != epoch) {
log.trace("Incoming message with wrong epoch, got {}, expected {}, message was: {}",
msg.getEpoch(), epoch, msg);
/* If this message was pending a completion, complete it with an error. */
completeExceptionally(msg.getRequestID(), new WrongEpochException(msg.getEpoch()));
return false;
}
return true;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, CorfuMsg m) throws Exception {
try {
// We get the handler for this message from the map
IClient handler = handlerMap.get(m.getMsgType());
if (handler == null) {
// The message was unregistered, we are dropping it.
log.warn("Received unregistered message {}, dropping", m);
} else {
if (validateEpochAndClientID(m, ctx)) {
// Route the message to the handler.
log.trace("Message routed to {}: {}", handler.getClass().getSimpleName(), m);
handler.handleMessage(m, ctx);
}
}
} catch (Exception e) {
log.error("Exception during read!", e);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
log.error("Exception during channel handling.", cause);
ctx.close();
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
context = ctx;
log.debug("Registered new channel {}", ctx);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
super.channelUnregistered(ctx);
context = null;
log.debug("Unregistered channel {}", ctx);
}
}