package org.corfudb.runtime.clients; import com.google.common.collect.ImmutableMap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.codec.LengthFieldPrepender; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; import io.netty.util.concurrent.DefaultEventExecutorGroup; import io.netty.util.concurrent.EventExecutorGroup; import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.corfudb.AbstractCorfuTest; import org.corfudb.infrastructure.BaseServer; import org.corfudb.infrastructure.NettyServerRouter; import org.corfudb.protocols.wireprotocol.NettyCorfuMessageDecoder; import org.corfudb.protocols.wireprotocol.NettyCorfuMessageEncoder; import org.corfudb.security.sasl.plaintext.PlainTextSaslNettyServer; import org.corfudb.security.tls.TlsUtils; import org.junit.Test; import java.io.IOException; import java.net.ServerSocket; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; import javax.net.ssl.SSLEngine; import static org.assertj.core.api.Assertions.assertThat; /** * Created by mwei on 3/28/16. 
*
* <p>End-to-end tests of Netty client/server communication: plain pings,
* pings across a server restart, TLS with and without mutual authentication
* and SASL plain-text authentication on top of TLS.
*/
@Slf4j
public class NettyCommTest extends AbstractCorfuTest {

    /** Directory holding the key stores, trust stores and credential files used by these tests. */
    private static final String SECURITY_DIR = "src/test/resources/security/";

    /** Password file used for every key store and trust store in these tests. */
    private static final String STORE_PASS = SECURITY_DIR + "storepass";

    /** JAAS configuration file installed before starting a SASL-enabled server. */
    private static final String JAAS_CONFIG = SECURITY_DIR + "corfudb_jaas.config";

    /**
     * Asks the OS for a currently free TCP port.
     *
     * <p>The probe socket is closed before returning, so another process could
     * in principle grab the port before the server binds it.
     *
     * @return a port number that was free at the time of the call
     * @throws IOException if the probe socket cannot be opened
     */
    private Integer findRandomOpenPort() throws IOException {
        try (ServerSocket socket = new ServerSocket(0)) {
            return socket.getLocalPort();
        }
    }

    /**
     * Builds a thread factory whose threads are named {@code prefix + n}, where
     * {@code n} counts up from 0 for each factory instance.
     *
     * @param prefix name prefix for created threads (e.g. {@code "io-"})
     * @return a new thread factory with its own counter
     */
    private static ThreadFactory namedThreadFactory(String prefix) {
        final AtomicInteger threadNum = new AtomicInteger(0);
        return r -> {
            Thread t = new Thread(r);
            t.setName(prefix + threadNum.getAndIncrement());
            return t;
        };
    }

    /**
     * Creates server data with TLS enabled, using the single cipher suite and
     * protocol shared by all TLS tests in this class.
     *
     * @param port       port the server will bind
     * @param keyStore   key store file name inside {@link #SECURITY_DIR}
     * @param trustStore trust store file name inside {@link #SECURITY_DIR}
     * @param mutualAuth whether the server requires client certificates
     * @return configured (not yet bootstrapped) server data
     * @throws Exception if TLS setup fails
     */
    private NettyServerData tlsServer(int port, String keyStore, String trustStore,
                                      boolean mutualAuth) throws Exception {
        NettyServerData d = new NettyServerData(port);
        String[] ciphers = {"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"};
        String[] protocols = {"TLSv1.2"};
        d.enableTls(
                SECURITY_DIR + keyStore, STORE_PASS,
                SECURITY_DIR + trustStore, STORE_PASS,
                mutualAuth, ciphers, protocols);
        return d;
    }

    /**
     * Creates a TLS server (s1.jks / trust1.jks, mutual auth required) that also
     * requires SASL plain-text authentication. Sets the
     * {@code java.security.auth.login.config} system property as a side effect.
     *
     * @param port port the server will bind
     * @return configured (not yet bootstrapped) server data
     * @throws Exception if TLS setup fails
     */
    private NettyServerData saslServer(int port) throws Exception {
        System.setProperty("java.security.auth.login.config", JAAS_CONFIG);
        NettyServerData d = tlsServer(port, "s1.jks", "trust1.jks", true);
        d.enableSaslPlainTextAuth();
        return d;
    }

    /**
     * Creates a TLS-enabled client router to localhost without SASL.
     *
     * @param port       server port
     * @param keyStore   key store file name inside {@link #SECURITY_DIR}
     * @param trustStore trust store file name inside {@link #SECURITY_DIR}
     * @return a client router that has not been started
     * @throws Exception if the router cannot be constructed
     */
    private NettyClientRouter tlsClient(int port, String keyStore, String trustStore)
            throws Exception {
        return new NettyClientRouter("localhost", port, true,
                SECURITY_DIR + keyStore, STORE_PASS,
                SECURITY_DIR + trustStore, STORE_PASS,
                false, null, null);
    }

    /**
     * Creates a TLS-enabled client router to localhost (r1.jks / trust1.jks)
     * that authenticates with SASL plain-text.
     *
     * @param port         server port
     * @param userFile     username file name inside {@link #SECURITY_DIR}
     * @param passwordFile password file name inside {@link #SECURITY_DIR}
     * @return a client router that has not been started
     * @throws Exception if the router cannot be constructed
     */
    private NettyClientRouter saslClient(int port, String userFile, String passwordFile)
            throws Exception {
        return new NettyClientRouter("localhost", port, true,
                SECURITY_DIR + "r1.jks", STORE_PASS,
                SECURITY_DIR + "trust1.jks", STORE_PASS,
                true, SECURITY_DIR + userFile, SECURITY_DIR + passwordFile);
    }

    @Test
    public void nettyServerClientPingable() throws Exception {
        runWithBaseServer(
                port -> new NettyServerData(port),
                port -> new NettyClientRouter("localhost", port),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isTrue());
    }

    @Test
    public void nettyServerClientPingableAfterFailure() throws Exception {
        runWithBaseServer(
                port -> new NettyServerData(port),
                port -> new NettyClientRouter("localhost", port),
                (r, d) -> {
                    assertThat(r.getClient(BaseClient.class).pingSync())
                            .isTrue();
                    d.shutdownServer();
                    d.bootstrapServer();
                    // NOTE(review): the result of this ping is not asserted; the client
                    // may not have reconnected yet. Consider polling until it succeeds.
                    r.getClient(BaseClient.class).pingSync();
                });
    }

    @Test
    public void nettyTlsNoMutualAuth() throws Exception {
        runWithBaseServer(
                port -> tlsServer(port, "s1.jks", "s1.jks", false),
                port -> tlsClient(port, "r1.jks", "trust1.jks"),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isTrue());
    }

    @Test
    public void nettyTlsMutualAuth() throws Exception {
        runWithBaseServer(
                port -> tlsServer(port, "s1.jks", "trust1.jks", true),
                port -> tlsClient(port, "r1.jks", "trust1.jks"),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isTrue());
    }

    @Test
    public void nettyTlsUnknownServer() throws Exception {
        runWithBaseServer(
                port -> tlsServer(port, "s3.jks", "trust1.jks", true),
                port -> tlsClient(port, "r1.jks", "trust2.jks"),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isFalse());
    }

    @Test
    public void nettyTlsUnknownClient() throws Exception {
        runWithBaseServer(
                port -> tlsServer(port, "s1.jks", "trust2.jks", true),
                port -> tlsClient(port, "r2.jks", "trust1.jks"),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isFalse());
    }

    @Test
    public void nettyTlsUnknownClientNoMutualAuth() throws Exception {
        runWithBaseServer(
                port -> tlsServer(port, "s1.jks", "trust2.jks", false),
                port -> tlsClient(port, "r2.jks", "trust1.jks"),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isTrue());
    }

    @Test
    public void nettySasl() throws Exception {
        runWithBaseServer(
                port -> saslServer(port),
                port -> saslClient(port, "username1", "userpass1"),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isTrue());
    }

    @Test
    public void nettySaslWrongPassword() throws Exception {
        runWithBaseServer(
                port -> saslServer(port),
                port -> saslClient(port, "username1", "userpass2"),
                (r, d) -> assertThat(r.getClient(BaseClient.class).pingSync())
                        .isFalse());
    }

    /**
     * Starts a server and a client on a random free port, runs {@code actionFn},
     * then stops the client (best effort) and shuts the server down.
     *
     * @param nsdc     creates the server description for the chosen port
     * @param ncrc     creates the client router for the chosen port
     * @param actionFn test body, given the started client and the server data
     * @throws Exception anything thrown by setup or by {@code actionFn} (logged, then rethrown)
     */
    void runWithBaseServer(NettyServerDataConstructor nsdc,
                           NettyClientRouterConstructor ncrc,
                           NettyCommFunction actionFn) throws Exception {
        int port = findRandomOpenPort();

        NettyServerData d = nsdc.createNettyServerData(port);
        NettyClientRouter ncr = null;
        try {
            d.bootstrapServer();
            ncr = ncrc.createNettyClientRouter(port);
            ncr.addClient(new BaseClient());
            ncr.start();
            actionFn.runTest(ncr, d);
        } catch (Exception ex) {
            log.error("Exception ", ex);
            throw ex;
        } finally {
            try {
                if (ncr != null) {
                    ncr.stop();
                }
            } catch (Exception ex) {
                log.warn("Error shutting down client...", ex);
            }
            // shutdownServer() tolerates a partially bootstrapped server, so it will
            // not mask the original failure with an NPE.
            d.shutdownServer();
        }
    }

    /** Creates the server description for a given port. */
    @FunctionalInterface
    public interface NettyServerDataConstructor {
        NettyServerData createNettyServerData(int port) throws Exception;
    }

    /** Creates the client router for a given port. */
    @FunctionalInterface
    public interface NettyClientRouterConstructor {
        NettyClientRouter createNettyClientRouter(int port) throws Exception;
    }

    /** Test body run against a started client and server. */
    @FunctionalInterface
    public interface NettyCommFunction {
        void runTest(NettyClientRouter r, NettyServerData d) throws Exception;
    }

    /**
     * Mutable description of a test Netty server: port, optional TLS/SASL
     * settings and the Netty resources created by {@link #bootstrapServer()}.
     */
    @Data
    public class NettyServerData {
        ServerBootstrap b;
        ChannelFuture f;
        int port;
        EventLoopGroup bossGroup;
        EventLoopGroup workerGroup;
        EventExecutorGroup ee;
        boolean tlsEnabled = false;
        SslContext sslContext;
        boolean tlsMutualAuthEnabled = false;
        String[] enabledTlsCipherSuites;
        String[] enabledTlsProtocols;
        boolean saslPlainTextAuthEnabled = false;

        public NettyServerData(int port) {
            this.port = port;
        }

        /**
         * Builds a server SSL context from the given stores and enables TLS for
         * channels created after this call.
         *
         * @param ksFile         key store path
         * @param ksPasswordFile key store password file path
         * @param tsFile         trust store path
         * @param tsPasswordFile trust store password file path
         * @param mutualAuth     require client certificates when true
         * @param ciphers        cipher suites to enable on each SSL engine
         * @param protocols      protocols to enable on each SSL engine
         * @throws Exception if the SSL context cannot be built
         */
        public void enableTls(String ksFile, String ksPasswordFile, String tsFile,
                              String tsPasswordFile, boolean mutualAuth,
                              String[] ciphers, String[] protocols) throws Exception {
            this.sslContext =
                    TlsUtils.enableTls(TlsUtils.SslContextType.SERVER_CONTEXT,
                            ksFile, e -> {
                                throw new RuntimeException("Could not load keys from the key " +
                                        "store: " + e.getClass().getSimpleName(), e);
                            },
                            ksPasswordFile, e -> {
                                throw new RuntimeException("Could not read the key store " +
                                        "password file: " + e.getClass().getSimpleName(), e);
                            },
                            tsFile, e -> {
                                throw new RuntimeException("Could not load keys from the trust " +
                                        "store: " + e.getClass().getSimpleName(), e);
                            },
                            tsPasswordFile, e -> {
                                throw new RuntimeException("Could not read the trust store " +
                                        "password file: " + e.getClass().getSimpleName(), e);
                            });
            this.tlsMutualAuthEnabled = mutualAuth;
            this.enabledTlsCipherSuites = ciphers;
            this.enabledTlsProtocols = protocols;
            this.tlsEnabled = true;
        }

        /** Adds a SASL plain-text handler to channels created after this call. */
        public void enableSaslPlainTextAuth() {
            this.saslPlainTextAuthEnabled = true;
        }

        /**
         * Creates fresh event loop / executor groups and binds the server to
         * {@link #port}, blocking until the bind completes.
         *
         * @throws Exception if binding fails
         */
        void bootstrapServer() throws Exception {
            NettyServerRouter nsr =
                    new NettyServerRouter(new ImmutableMap.Builder<String, Object>().build());
            final int threads = Runtime.getRuntime().availableProcessors() * 2;
            bossGroup = new NioEventLoopGroup(1, namedThreadFactory("accept-"));
            workerGroup = new NioEventLoopGroup(threads, namedThreadFactory("io-"));
            ee = new DefaultEventExecutorGroup(threads, namedThreadFactory("event-"));

            final int SO_BACKLOG = 100;
            final int FRAME_SIZE = 4;
            b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.SO_BACKLOG, SO_BACKLOG)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.SO_REUSEADDR, true)
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            if (tlsEnabled) {
                                SSLEngine engine = sslContext.newEngine(ch.alloc());
                                engine.setEnabledCipherSuites(enabledTlsCipherSuites);
                                engine.setEnabledProtocols(enabledTlsProtocols);
                                if (tlsMutualAuthEnabled) {
                                    engine.setNeedClientAuth(true);
                                }
                                ch.pipeline().addLast("ssl", new SslHandler(engine));
                            }
                            // Frames carry a FRAME_SIZE-byte length prefix, stripped on decode.
                            ch.pipeline().addLast(new LengthFieldPrepender(FRAME_SIZE));
                            ch.pipeline().addLast(new LengthFieldBasedFrameDecoder(
                                    Integer.MAX_VALUE, 0, FRAME_SIZE, 0, FRAME_SIZE));
                            if (saslPlainTextAuthEnabled) {
                                ch.pipeline().addLast("sasl/plain-text",
                                        new PlainTextSaslNettyServer());
                            }
                            // Message codec and router run on the separate executor group.
                            ch.pipeline().addLast(ee, new NettyCorfuMessageDecoder());
                            ch.pipeline().addLast(ee, new NettyCorfuMessageEncoder());
                            ch.pipeline().addLast(ee, nsr);
                        }
                    });
            f = b.bind(port).sync();
        }

        /**
         * Closes the server channel and shuts down every group created by
         * {@link #bootstrapServer()}. Safe to call when bootstrapping failed
         * partway, since each resource is null-checked.
         */
        public void shutdownServer() {
            if (f != null) {
                f.channel().close().awaitUninterruptibly();
            }
            if (bossGroup != null) {
                bossGroup.shutdownGracefully();
            }
            if (workerGroup != null) {
                workerGroup.shutdownGracefully();
            }
            // The executor group must be released as well, otherwise its threads
            // outlive the server (and accumulate across restarts).
            if (ee != null) {
                ee.shutdownGracefully();
            }
        }
    }
}