package com.lambdaworks.redis;
import static com.lambdaworks.redis.ConnectionEventTrigger.local;
import static com.lambdaworks.redis.ConnectionEventTrigger.remote;
import static com.lambdaworks.redis.PlainChannelInitializer.INITIALIZING_CMD_BUILDER;
import static com.lambdaworks.redis.PlainChannelInitializer.pingBeforeActivate;
import static com.lambdaworks.redis.PlainChannelInitializer.removeIfExists;
import java.io.IOException;
import java.io.InputStream;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import javax.net.ssl.*;
import com.lambdaworks.redis.event.EventBus;
import com.lambdaworks.redis.event.connection.ConnectedEvent;
import com.lambdaworks.redis.event.connection.ConnectionActivatedEvent;
import com.lambdaworks.redis.event.connection.DisconnectedEvent;
import com.lambdaworks.redis.internal.LettuceAssert;
import com.lambdaworks.redis.protocol.AsyncCommand;
import io.netty.channel.*;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
/**
 * Connection builder for SSL connections. This class is part of the internal API.
 * <p>
 * A {@link RedisURI} with SSL enabled must be supplied through {@link #ssl(RedisURI)} before {@link #build()} is called.
 * The resulting {@link RedisChannelInitializer} installs an {@link SslHandler} and holds back channel activation until
 * the TLS handshake (and, if configured, an initial PING/AUTH) has completed.
 *
 * @author Mark Paluch
 */
public class SslConnectionBuilder extends ConnectionBuilder {

    private RedisURI redisURI;

    /**
     * Set the {@link RedisURI} that provides host, port, peer verification, StartTLS and password settings.
     *
     * @param redisURI the Redis URI; must be non-null and SSL-enabled by the time handlers are built
     * @return {@code this} builder
     */
    public SslConnectionBuilder ssl(RedisURI redisURI) {
        this.redisURI = redisURI;
        return this;
    }

    /**
     * @return a new, unconfigured {@link SslConnectionBuilder}
     */
    public static SslConnectionBuilder sslConnectionBuilder() {
        return new SslConnectionBuilder();
    }

    /**
     * Validate the SSL configuration (via {@code LettuceAssert.assertState}) and then delegate to the superclass.
     */
    @Override
    protected List<ChannelHandler> buildHandlers() {
        LettuceAssert.assertState(redisURI != null, "RedisURI must not be null");
        LettuceAssert.assertState(redisURI.isSsl(), "RedisURI is not configured for SSL (ssl is false)");
        return super.buildHandlers();
    }

    /**
     * Create the SSL channel initializer from the validated handlers, client options and client resources.
     */
    @Override
    public RedisChannelInitializer build() {
        final List<ChannelHandler> channelHandlers = buildHandlers();
        return new SslChannelInitializer(clientOptions().isPingBeforeActivateConnection(), channelHandlers, redisURI,
                clientResources().eventBus(), clientOptions().getSslOptions());
    }

    /**
     * Channel initializer that sets up TLS and gates {@code channelActive} until the handshake (and optional
     * PING/AUTH) succeeds. Initialization progress is exposed through {@link #channelInitialized()}.
     *
     * @author Mark Paluch
     */
    static class SslChannelInitializer extends io.netty.channel.ChannelInitializer<Channel> implements RedisChannelInitializer {

        private final boolean pingBeforeActivate;
        private final List<ChannelHandler> handlers;
        private final RedisURI redisURI;
        private final EventBus eventBus;
        private final SslOptions sslOptions;

        // Replaced with a fresh future whenever the channel becomes inactive (see the "channelActivator" handler).
        // NOTE(review): this field is not volatile; if channelInitialized() is read from threads other than the
        // event loop, consider making it volatile — confirm against callers.
        private CompletableFuture<Boolean> initializedFuture = new CompletableFuture<>();

        public SslChannelInitializer(boolean pingBeforeActivate, List<ChannelHandler> handlers, RedisURI redisURI,
                EventBus eventBus, SslOptions sslOptions) {
            this.pingBeforeActivate = pingBeforeActivate;
            this.handlers = handlers;
            this.redisURI = redisURI;
            this.eventBus = eventBus;
            this.sslOptions = sslOptions;
        }

        /**
         * Install the SSL handler, the connect/disconnect event publisher, the activation gate and the configured
         * command handlers on {@code channel}'s pipeline. Handlers that may remain from a previous initialization are
         * either reused ("first", "channelActivator") or removed and re-added (SslHandler, command handlers).
         */
        @Override
        protected void initChannel(Channel channel) throws Exception {

            // Build the engine before touching the pipeline so a configuration failure leaves the pipeline unchanged.
            SSLEngine sslEngine = createSslEngine(channel);

            removeIfExists(channel.pipeline(), SslHandler.class);

            if (channel.pipeline().get("first") == null) {
                channel.pipeline().addFirst("first", new ChannelDuplexHandler() {

                    @Override
                    public void channelActive(ChannelHandlerContext ctx) throws Exception {
                        eventBus.publish(new ConnectedEvent(local(ctx), remote(ctx)));
                        super.channelActive(ctx);
                    }

                    @Override
                    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                        eventBus.publish(new DisconnectedEvent(local(ctx), remote(ctx)));
                        super.channelInactive(ctx);
                    }
                });
            }

            SslHandler sslHandler = new SslHandler(sslEngine, redisURI.isStartTls());
            channel.pipeline().addLast(sslHandler);

            if (channel.pipeline().get("channelActivator") == null) {
                channel.pipeline().addLast("channelActivator", new RedisChannelInitializerImpl() {

                    private AsyncCommand<?, ?, ?> pingCommand;

                    @Override
                    public CompletableFuture<Boolean> channelInitialized() {
                        return initializedFuture;
                    }

                    @Override
                    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                        // Reset so a reconnect gets a fresh, uncompleted initialization future.
                        initializedFuture = new CompletableFuture<>();
                        pingCommand = null;
                        super.channelInactive(ctx);
                    }

                    @Override
                    public void channelActive(ChannelHandlerContext ctx) throws Exception {
                        // Swallow the TCP-level activation; it is re-fired once the TLS handshake completes.
                        if (initializedFuture.isDone()) {
                            super.channelActive(ctx);
                        }
                    }

                    @Override
                    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {

                        if (evt instanceof SslHandshakeCompletionEvent && !initializedFuture.isDone()) {

                            SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
                            if (event.isSuccess()) {
                                if (pingBeforeActivate) {
                                    // Use AUTH as the probe when a password is configured, otherwise PING.
                                    if (redisURI.getPassword() != null && redisURI.getPassword().length != 0) {
                                        pingCommand = new AsyncCommand<>(
                                                INITIALIZING_CMD_BUILDER.auth(new String(redisURI.getPassword())));
                                    } else {
                                        pingCommand = new AsyncCommand<>(INITIALIZING_CMD_BUILDER.ping());
                                    }
                                    pingBeforeActivate(pingCommand, initializedFuture, ctx, handlers);
                                } else {
                                    ctx.fireChannelActive();
                                }
                            } else {
                                initializedFuture.completeExceptionally(event.cause());
                            }
                        }

                        if (evt instanceof ConnectionEvents.Close) {
                            if (ctx.channel().isOpen()) {
                                ctx.channel().close();
                            }
                        }

                        if (evt instanceof ConnectionEvents.Activated) {
                            if (!initializedFuture.isDone()) {
                                initializedFuture.complete(true);
                                eventBus.publish(new ConnectionActivatedEvent(local(ctx), remote(ctx)));
                            }
                        }

                        super.userEventTriggered(ctx, evt);
                    }

                    @Override
                    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                        // Surface TLS failures to whoever awaits initialization.
                        if (cause instanceof SSLHandshakeException || cause.getCause() instanceof SSLException) {
                            initializedFuture.completeExceptionally(cause);
                        }
                        super.exceptionCaught(ctx, cause);
                    }
                });
            }

            for (ChannelHandler handler : handlers) {
                removeIfExists(channel.pipeline(), handler.getClass());
                channel.pipeline().addLast(handler);
            }
        }

        @Override
        public CompletableFuture<Boolean> channelInitialized() {
            return initializedFuture;
        }

        /**
         * Build a new client {@link SSLEngine} for {@code channel} from the current {@link SslOptions} and
         * {@link RedisURI}. The SSL context (including the truststore) is rebuilt on every call.
         *
         * @param channel the channel whose allocator is used for the engine
         * @return a configured engine for {@code redisURI}'s host and port
         */
        private SSLEngine createSslEngine(Channel channel) throws IOException, GeneralSecurityException {

            SSLParameters sslParams = new SSLParameters();
            SslContextBuilder sslContextBuilder = SslContextBuilder.forClient().sslProvider(sslOptions.getSslProvider());

            if (redisURI.isVerifyPeer()) {
                // Enables hostname verification on top of certificate chain validation.
                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
            } else {
                sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
            }

            // A configured truststore replaces any trust manager set above (the later trustManager call wins).
            if (sslOptions.getTruststore() != null) {
                try (InputStream is = sslOptions.getTruststore().openStream()) {
                    sslContextBuilder.trustManager(createTrustManagerFactory(is, truststorePassword()));
                }
            }

            SslContext sslContext = sslContextBuilder.build();
            SSLEngine sslEngine = sslContext.newEngine(channel.alloc(), redisURI.getHost(), redisURI.getPort());
            sslEngine.setSSLParameters(sslParams);
            return sslEngine;
        }

        /**
         * @return the configured truststore password, or {@code null} when none (null or empty) is configured
         */
        private char[] truststorePassword() {
            char[] password = sslOptions.getTruststorePassword();
            return password == null || password.length == 0 ? null : password;
        }

        /**
         * Create a {@link TrustManagerFactory} backed by a key store in the platform default format.
         *
         * @param inputStream the key store contents; NOT closed by this method, the caller owns the stream
         * @param storePassword the key store password, or {@code null} to load without an integrity check
         * @return an initialized trust manager factory using the default algorithm
         * @throws GeneralSecurityException if the key store type/algorithm is unavailable or the store is invalid
         * @throws IOException if reading the stream fails or the password is incorrect
         */
        private static TrustManagerFactory createTrustManagerFactory(InputStream inputStream, char[] storePassword)
                throws GeneralSecurityException, IOException {

            KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
            trustStore.load(inputStream, storePassword);

            TrustManagerFactory trustManagerFactory = TrustManagerFactory
                    .getInstance(TrustManagerFactory.getDefaultAlgorithm());
            trustManagerFactory.init(trustStore);
            return trustManagerFactory;
        }
    }
}