/* * Copyright (C) 2012-2016 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.nifty.server; import com.facebook.nifty.client.FramedClientConnector; import com.facebook.nifty.client.NettyClientConfig; import com.facebook.nifty.client.NiftyClient; import com.facebook.nifty.client.TNiftyClientChannelTransport; import com.facebook.nifty.core.NettyServerConfig; import com.facebook.nifty.core.NettyServerTransport; import com.facebook.nifty.core.RequestContext; import com.facebook.nifty.core.RequestContexts; import com.facebook.nifty.core.ThriftServerDefBuilder; import com.facebook.nifty.ssl.OpenSslServerConfiguration; import com.facebook.nifty.ssl.PollingMultiFileWatcher; import com.facebook.nifty.ssl.SslClientConfiguration; import com.facebook.nifty.ssl.SslConfigFileWatcher; import com.facebook.nifty.ssl.SslServerConfiguration; import com.facebook.nifty.ssl.TicketSeedFileParser; import com.facebook.nifty.ssl.TransportAttachObserver; import com.facebook.nifty.test.LogEntry; import com.facebook.nifty.test.ResultCode; import com.facebook.nifty.test.scribe; import com.google.common.collect.ImmutableList; import com.google.common.io.Files; import io.airlift.log.Logger; import io.airlift.units.Duration; import org.apache.thrift.TApplicationException; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; import org.apache.tomcat.jni.SessionTicketKey; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.handler.ssl.HackyJdkSslClientContext; import org.jboss.netty.handler.ssl.SslContext; import org.jboss.netty.handler.ssl.SslHandler; import org.testng.Assert; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; import javax.net.ssl.TrustManagerFactory; import javax.security.auth.x500.X500Principal; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.net.Socket; import java.security.KeyStore; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.Function; import static java.util.Objects.requireNonNull; public class TestNiftyOpenSslServer { private static final Logger log = Logger.get(TestNiftyOpenSslServer.class); private NettyServerTransport server; private int port; private PollingMultiFileWatcher fileWatcher = null; // Server-side configs private File ticketSeedFile = null; private File privateKeyFile = null; private File serverCertFile = null; // Client-side configs private File clientCertFile = null; private File clientPKCS12File = null; // Password provided to the openssl command line tool when creating the client.pkcs12 file private static final String CLIENT_PKCS12_PASSWORD = "12345"; @BeforeMethod(alwaysRun = true) public void setup() { server = null; fileWatcher = new PollingMultiFileWatcher(Duration.valueOf("0 ms"), Duration.valueOf("100 ms")); } @AfterMethod(alwaysRun = true) public void teardown() throws InterruptedException { if (server != null) { server.stop(); } fileWatcher = null; deleteFilesIfExistIgnoreErrors( ticketSeedFile, privateKeyFile, serverCertFile, clientCertFile, clientPKCS12File); ticketSeedFile = privateKeyFile = serverCertFile = clientCertFile = clientPKCS12File = null; } private void startServer() { startServer(false); } private void startServer(boolean allowPlaintext) { try { List<SessionTicketKey> ticketKeysList = new TicketSeedFileParser().parse(getTicketSeedFile()); SessionTicketKey[] ticketKeys = ticketKeysList.toArray(new SessionTicketKey[ticketKeysList.size()]); SslConfigFileWatcher configUpdater = new SslConfigFileWatcher( getTicketSeedFile(), getPrivateKeyFile(), getServerCertFile(), null, fileWatcher); SslServerConfiguration config = createSSLServerConfiguration(allowPlaintext, ticketKeys); long callbacksSucceeded = fileWatcher.getStats().getCallbacksSucceeded(); startServer(getThriftServerDefBuilder(config, configUpdater)); while (fileWatcher.getStats().getCallbacksSucceeded() < callbacksSucceeded + 1) { Thread.sleep(25); // Wait for first callback to process } } catch (InterruptedException | IOException e) { throw new RuntimeException(e); } } private void startServer(final ThriftServerDefBuilder thriftServerDefBuilder) { server = new NettyServerTransport(thriftServerDefBuilder.build(), NettyServerConfig.newBuilder().build(), new DefaultChannelGroup()); server.start(); port = ((InetSocketAddress)server.getServerChannel().getLocalAddress()).getPort(); } SslServerConfiguration createSSLServerConfiguration(boolean allowPlaintext, SessionTicketKey[] ticketKeys) throws IOException { return OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(allowPlaintext) .ticketKeys(ticketKeys) .build(); } private ThriftServerDefBuilder getThriftServerDefBuilder( SslServerConfiguration sslServerConfiguration, TransportAttachObserver configUpdater) { return getThriftServerDefBuilder(sslServerConfiguration, configUpdater, (List<LogEntry> entries) -> ResultCode.OK); } private ThriftServerDefBuilder getThriftServerDefBuilder( SslServerConfiguration sslServerConfiguration, TransportAttachObserver configUpdater, final Function<List<LogEntry>, ResultCode> thriftHandler) { requireNonNull(thriftHandler); return new ThriftServerDefBuilder() .listen(0) .withSSLConfiguration(sslServerConfiguration) .withTransportAttachObserver(configUpdater) .withProcessor(new scribe.Processor<>(new scribe.Iface() { @Override public ResultCode Log(List<LogEntry> messages) throws TException { RequestContext context = RequestContexts.getCurrentContext(); for (LogEntry message : messages) { log.info("[Client: %s] %s: %s", context.getConnectionContext().getRemoteAddress(), message.getCategory(), message.getMessage()); } try { return thriftHandler.apply(messages); } catch (Exception e) { throw new TException(e); } } })); } private SslClientConfiguration getClientSSLConfiguration() throws IOException { return getClientSSLConfiguration(null); } private SslClientConfiguration getClientSSLConfiguration(File certFile) throws IOException { return getClientSSLConfiguration(certFile, null); } private SslClientConfiguration getClientSSLConfiguration(File certFile, KeyManager[] keyManagers) throws IOException { SslContext context = new HackyJdkSslClientContext( null, certFile == null ? getServerCertFile() : certFile, keyManagers, null, null, null, 10000, 10000 ); return new SslClientConfiguration.Builder().sslContext(context).build(); } private scribe.Client makeNiftyClient(SslClientConfiguration clientSSLConfiguration) throws TTransportException, InterruptedException { NettyClientConfig config = NettyClientConfig.newBuilder() .setSSLClientConfiguration(clientSSLConfiguration).build(); InetSocketAddress address = new InetSocketAddress("localhost", port); TTransport transport = new NiftyClient(config) .connectSync(scribe.Client.class, new FramedClientConnector(address)); TProtocol protocol = new TBinaryProtocol(transport); return new scribe.Client(protocol); } private scribe.Client makeNiftyPlaintextClient() throws TTransportException, InterruptedException { NettyClientConfig config = NettyClientConfig.newBuilder().build(); InetSocketAddress address = new InetSocketAddress("localhost", port); TTransport transport = new NiftyClient(config) .connectSync(scribe.Client.class, new FramedClientConnector(address)); TProtocol protocol = new TBinaryProtocol(transport); return new scribe.Client(protocol); } /** * Returns a file path to the given resource loaded using the given class's class loader. * * @param clazz the class whose class loader should be used to load the resource. * @param resourcePath the resource path. * @return a File object representing the path to the resource. */ private File getResourceFile(Class<?> clazz, String resourcePath) { return new File(clazz.getResource(resourcePath).getFile()); } /** * Returns the contents of the given resource loaded using the given class's class loader. * * @param clazz the class whose class loader should be used to load the resource. * @param resourcePath the resource path. * @return the contents of the resource file. * @throws IOException if the resource file could not be read. */ private byte[] getResourceFileContents(Class<?> clazz, String resourcePath) throws IOException { return Files.toByteArray(getResourceFile(clazz, resourcePath)); } /** * Overwrites the contents of the given file with the given byte array. If the file does not exist, it will * be created. * * @param file the file to overwrite. * @param newContents new file contents. * @throws IOException if the write fails. */ private void overwriteFile(File file, byte[] newContents) throws IOException { java.nio.file.Files.write(file.toPath(), newContents); } /** * Best-effort attempt to delete all of the given files if they exist. Ignores errors. * * @param files the files to delete. */ private void deleteFilesIfExistIgnoreErrors(File... files) { for (File file : files) { if (file != null) { try { java.nio.file.Files.deleteIfExists(file.toPath()); } catch (IOException e) { // silently ignore delete errors } } } } /** * Creates a temp file with the same contents as the given resource. Returns the path to the temp file. * The temp file should be deleted by the user when the test finishes. * * @param clazz the class whose class loader should be used to load the resource. * @param resourcePath the resource path. * @return a File object representing the path to the new temp file. * @throws IOException if the resource file could not be read, or temp file could not be created or written. */ private File initTempFileFromResource(Class<?> clazz, String resourcePath) throws IOException { File result = File.createTempFile("test_nifty_openssl_server", resourcePath.replaceAll("/", "_")); overwriteFile(result, getResourceFileContents(clazz, resourcePath)); return result; } /** * Returns the path to a temporary ticket seed file. If the temp file does not yet exist, it is created on * demand and initialized with the contents of the "/ticket_seeds.json" resource. * The temp file should be deleted by the user when the test finishes. * * @return the new file. * @throws IOException if reading the resource or creating the temp file fails. */ private File getTicketSeedFile() throws IOException { if (ticketSeedFile == null) { ticketSeedFile = initTempFileFromResource(Plain.class, "/ticket_seeds.json"); } return ticketSeedFile; } /** * Overwrites the contents of the ticket seed file with the given byte array. * * @param newContents new ticket seed file contents. * @throws IOException if writing the file fails. */ private void updateTicketSeedFile(byte[] newContents) throws IOException { overwriteFile(getTicketSeedFile(), newContents); } /** * Returns the path to a temporary private key file. If the temp file does not yet exist, it is created on * demand and initialized with the contents of the "/rsa.key" resource. * The temp file should be deleted by the user when the test finishes. * * @return the new file. * @throws IOException if reading the resource or creating the temp file fails. */ private File getPrivateKeyFile() throws IOException { if (privateKeyFile == null) { privateKeyFile = initTempFileFromResource(Plain.class, "/rsa.key"); } return privateKeyFile; } /** * Overwrites the contents of the private key file with the given byte array. * * @param newContents new private key file contents. * @throws IOException if writing the file fails. */ private void updatePrivateKeyFile(byte[] newContents) throws IOException { overwriteFile(getPrivateKeyFile(), newContents); } /** * Returns the path to a temporary server certificate file. If the temp file does not yet exist, * it is created on demand and initialized with the contents of the "/rsa.crt" resource. * The temp file should be deleted by the user when the test finishes. * * @return the new file. * @throws IOException if reading the resource or creating the temp file fails. */ private File getServerCertFile() throws IOException { if (serverCertFile == null) { serverCertFile = initTempFileFromResource(Plain.class, "/rsa.crt"); } return serverCertFile; } /** * Overwrites the contents of the server certificate file with the given byte array. * * @param newContents new certificate file contents. * @throws IOException if writing the file fails. */ private void updateServerCertFile(byte[] newContents) throws IOException { overwriteFile(getServerCertFile(), newContents); } /** * Returns the path to a temporary client certificate file. If the temp file does not yet exist, * it is created on demand and initialized with the contents of the "/client.crt" resource. * The temp file should be deleted by the user when the test finishes. * * @return the new file. * @throws IOException if reading the resource or creating the temp file fails. */ private File getClientCertFile() throws IOException { if (clientCertFile == null) { clientCertFile = initTempFileFromResource(Plain.class, "/client.crt"); } return clientCertFile; } /** * Overwrites the contents of the certificate file with the given byte array. * * @param newContents new certificate file contents. * @throws IOException if writing the file fails. */ private void updateClientCertFile(byte[] newContents) throws IOException { overwriteFile(getClientCertFile(), newContents); } /** * Returns the path to a temporary client PKCS12 key file. If the temp file does not yet exist, * it is created ondemand and initialized with the contents of the "/client.pkcs12" resource. * The temp file should be deleted by the user when the test finishes. * * @return the new file. * @throws IOException if reading the resource or creating the temp file fails. */ private File getClientPKCS12File() throws IOException { if (clientPKCS12File == null) { clientPKCS12File = initTempFileFromResource(Plain.class, "/client.pkcs12"); } return clientPKCS12File; } /** * Overwrites the contents of the client PKCS12 key file with the given byte array. * * @param newContents new certificate file contents. * @throws IOException if writing the file fails. */ private void updateClientPKCS12File(byte[] newContents) throws IOException { overwriteFile(getClientPKCS12File(), newContents); } /** * Asserts that the given lists of session ticket keys are the same. {@link SessionTicketKey} seems to not * implement a proper equals() method so we have to do this the hard way. * * @param actualKeys the actual ticket keys. * @param expectedKeys the expected ticket keys. */ private void assertTicketKeysEqual(List<SessionTicketKey> actualKeys, List<SessionTicketKey> expectedKeys) { Assert.assertEquals(actualKeys.size(), expectedKeys.size()); for (int i = 0; i < actualKeys.size(); ++i) { SessionTicketKey actualKey = actualKeys.get(i); SessionTicketKey expectedKey = expectedKeys.get(i); Assert.assertEquals(actualKey.getAesKey(), expectedKey.getAesKey()); Assert.assertEquals(actualKey.getHmacKey(), expectedKey.getHmacKey()); Assert.assertEquals(actualKey.getName(), expectedKey.getName()); } } @Test public void testSSL() throws InterruptedException, TException, IOException { startServer(); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "bbb"))), ResultCode.OK); scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "ccc"))), ResultCode.OK); } @Test public void testSSLWithPlaintextAllowedServer() throws InterruptedException, TException, IOException { startServer(true); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "bbb"))), ResultCode.OK); scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "ccc"))), ResultCode.OK); } @Test(expectedExceptions = TTransportException.class) public void testUnencryptedClient() throws InterruptedException, TException { startServer(); scribe.Client client = makeNiftyPlaintextClient(); client.Log(Arrays.asList(new LogEntry("client2", "aaa"))); client.Log(Arrays.asList(new LogEntry("client2", "bbb"))); client.Log(Arrays.asList(new LogEntry("client2", "ccc"))); } @Test public void testUnencryptedClientWithAllowPlaintextServer() throws InterruptedException, TException, IOException { startServer(true); scribe.Client client = makeNiftyPlaintextClient(); client.Log(Arrays.asList(new LogEntry("client2", "aaa"))); client.Log(Arrays.asList(new LogEntry("client2", "bbb"))); client.Log(Arrays.asList(new LogEntry("client2", "ccc"))); } private KeyManager[] getClientKeyManagers() throws SSLException { try { KeyStore keyStore = KeyStore.getInstance("PKCS12"); try (InputStream keyInput = new FileInputStream(getClientPKCS12File())) { keyStore.load(keyInput, CLIENT_PKCS12_PASSWORD.toCharArray()); } KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance( KeyManagerFactory.getDefaultAlgorithm()); keyManagerFactory.init(keyStore, CLIENT_PKCS12_PASSWORD.toCharArray()); return keyManagerFactory.getKeyManagers(); } catch (Exception e) { throw new SSLException(e); } } private void startRawSSLClient(long delay) throws SSLException { try { KeyStore keyStore = KeyStore.getInstance("JKS"); keyStore.load(null, null); CertificateFactory cf = CertificateFactory.getInstance("X.509"); X509Certificate cert = (X509Certificate) cf.generateCertificate(new FileInputStream(getServerCertFile())); X500Principal principal = cert.getSubjectX500Principal(); keyStore.setCertificateEntry(principal.getName("RFC2253"), cert); TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(keyStore); KeyManager[] clientKeyManagers = getClientKeyManagers(); SSLContext context = SSLContext.getInstance("TLS"); context.init(clientKeyManagers, trustManagerFactory.getTrustManagers(), null); Socket sock = new Socket(); sock.connect(new InetSocketAddress("localhost", port)); if (delay != 0) { Thread.sleep(delay); } SSLSocket sslSocket = (SSLSocket) context.getSocketFactory().createSocket(sock, "localhost", port, true); sslSocket.startHandshake(); SSLSession session = sslSocket.getSession(); Assert.assertTrue(session.isValid()); sslSocket.close(); } catch (Throwable t) { throw new SSLException(t); } } @Test public void testDefaultServerWithClientCert() throws InterruptedException, IOException, TException { SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(false) .clientCAFile(getClientCertFile()) .build(); ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null); startServer(builder); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers())); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); } @Test public void testOptionalClientAuthenticatingServer() throws InterruptedException, IOException, TException { SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(false) .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_OPTIONAL) .clientCAFile(getClientCertFile()) .build(); ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null); startServer(builder); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers())); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "aaa"))), ResultCode.OK); } @Test public void testClientAuthenticatingServer() throws InterruptedException, IOException, TException { SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(false) .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE) .clientCAFile(getClientCertFile()) .build(); ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null); startServer(builder); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers())); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); } @Test public void testClientAuthenticatingServerAllowPlaintext() throws InterruptedException, IOException, TException { SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(true) .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE) .clientCAFile(getClientCertFile()) .build(); ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null); startServer(builder); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers())); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); scribe.Client client2 = makeNiftyPlaintextClient(); Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "aaa"))), ResultCode.OK); } @Test public void testThreadLocalSslBufferPool() throws InterruptedException, IOException, TException { SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(false) .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE) .clientCAFile(getClientCertFile()) .threadLocalSslBuffer(true) .build(); ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null); startServer(builder); scribe.Client client = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers())); Assert.assertEquals(client.Log(Arrays.asList(new LogEntry("client", "aaa"))), ResultCode.OK); } @Test(expectedExceptions = TTransportException.class) public void testClientWithoutCerts() throws InterruptedException, IOException, TException { SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(false) .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE) .clientCAFile(getClientCertFile()) .build(); startServer(getThriftServerDefBuilder(serverConfig, null)); scribe.Client client = makeNiftyClient(getClientSSLConfiguration()); client.Log(Arrays.asList(new LogEntry("client", "aaa"))); } @Test(expectedExceptions = SSLException.class) public void testWithServerIdleTimeout() throws TException, InterruptedException, IOException, NoSuchAlgorithmException { startServer(getThriftServerDefBuilder(createSSLServerConfiguration(false, null), null) .clientIdleTimeout(Duration.succinctDuration(1, TimeUnit.MILLISECONDS))); startRawSSLClient(200); } @Test(expectedExceptions = SSLException.class) public void testWithServerIdleTimeoutAllowPlaintext() throws TException, InterruptedException, IOException, NoSuchAlgorithmException { startServer(getThriftServerDefBuilder(createSSLServerConfiguration(true, null), null) .clientIdleTimeout(Duration.succinctDuration(1, TimeUnit.MILLISECONDS))); startRawSSLClient(200); } @Test(expectedExceptions = TApplicationException.class, expectedExceptionsMessageRegExp = "Internal error processing Log") public void testPlaintextServerThrowsException() throws InterruptedException, IOException, TException { startServer(getThriftServerDefBuilder( createSSLServerConfiguration(true /* allowPlaintext */, null), null, (List<LogEntry> messages) -> { throw new RuntimeException("Error"); })); scribe.Client client = makeNiftyPlaintextClient(); client.Log(Arrays.asList(new LogEntry("client", "aaa"))); } @Test(expectedExceptions = TApplicationException.class, expectedExceptionsMessageRegExp = "Internal error processing Log") public void testDefaultServerThrowsException() throws InterruptedException, IOException, TException { startServer(getThriftServerDefBuilder( createSSLServerConfiguration(false, null), null, (List<LogEntry> messages) -> { throw new RuntimeException("Error"); })); scribe.Client client = makeNiftyClient(getClientSSLConfiguration()); client.Log(Arrays.asList(new LogEntry("client", "aaa"))); } @Test(expectedExceptions = TApplicationException.class, expectedExceptionsMessageRegExp = "Internal error processing Log") public void testClientAuthenticatingServerThrowsException() throws InterruptedException, IOException, TException { SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder() .certFile(getServerCertFile()) .keyFile(getPrivateKeyFile()) .allowPlaintext(false) .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE) .clientCAFile(getClientCertFile()) .build(); startServer(getThriftServerDefBuilder( serverConfig, null, (List<LogEntry> messages) -> { throw new RuntimeException("Error"); })); scribe.Client client = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers())); client.Log(Arrays.asList(new LogEntry("client", "aaa"))); } @Test public void testSSLSessionResumption() throws Exception { // Ticket resumes are not supported by nifty client, so we test stateful session resumption // only. SessionTicketKey[] keys = { createSessionTicketKey() }; SslServerConfiguration sslServerConfiguration = createSSLServerConfiguration(true, keys); startServer(getThriftServerDefBuilder(sslServerConfiguration, null)); SslClientConfiguration sslClientConfiguration = getClientSSLConfiguration(); scribe.Client client1 = makeNiftyClient(sslClientConfiguration); client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))); Assert.assertFalse(isSessionResumed(getSSLSession(client1))); scribe.Client client2 = makeNiftyClient(sslClientConfiguration); client2.Log(Arrays.asList(new LogEntry("client2", "aaa"))); Assert.assertTrue(isSessionResumed(getSSLSession(client2))); client2.Log(Arrays.asList(new LogEntry("client2", "bbb"))); Assert.assertTrue(isSessionResumed(getSSLSession(client2))); SessionTicketKey[] keys2 = { createSessionTicketKey() }; SslServerConfiguration sslServerConfiguration2 = createSSLServerConfiguration(true, keys2); server.updateSSLConfiguration(sslServerConfiguration2); scribe.Client client3 = makeNiftyClient(sslClientConfiguration); client3.Log(Arrays.asList(new LogEntry("client3", "aaa"))); Assert.assertFalse(isSessionResumed(getSSLSession(client3))); scribe.Client client4 = makeNiftyClient(sslClientConfiguration); client4.Log(Arrays.asList(new LogEntry("client4", "aaa"))); Assert.assertTrue(isSessionResumed(getSSLSession(client4))); } class TestConfigUpdater implements TransportAttachObserver { public NettyServerTransport attachedTransport; @Override public void attachTransport(NettyServerTransport transport) { attachedTransport = transport; } @Override public void detachTransport() { attachedTransport = null; } void updateSSLConfig(SslServerConfiguration newConfig) { attachedTransport.updateSSLConfiguration(newConfig); } }; @Test public void testAttachTransportToUpdater() throws InterruptedException, IOException { TestConfigUpdater configUpdater = new TestConfigUpdater(); SessionTicketKey[] keys = { createSessionTicketKey() }; SslServerConfiguration sslServerConfiguration = createSSLServerConfiguration(true, keys); startServer(getThriftServerDefBuilder(sslServerConfiguration, configUpdater)); Assert.assertNotNull(configUpdater.attachedTransport); SessionTicketKey[] newKeys = { createSessionTicketKey() }; SslServerConfiguration newConfig = createSSLServerConfiguration(true, newKeys); configUpdater.updateSSLConfig(newConfig); server.stop(); server = null; Assert.assertNull(configUpdater.attachedTransport); } @Test public void testRotateTicketSeedFile() throws InterruptedException, IOException { startServer(); OpenSslServerConfiguration config = (OpenSslServerConfiguration) server.getSSLConfiguration(); List<SessionTicketKey> actual = ImmutableList.copyOf(config.ticketKeys); List<SessionTicketKey> expected = new TicketSeedFileParser().parse(getTicketSeedFile()); assertTicketKeysEqual(actual, expected); // Rotate the ticket seeds file long callbacksSucceeded = fileWatcher.getStats().getCallbacksSucceeded(); updateTicketSeedFile(getResourceFileContents(Plain.class, "/ticket_seeds2.json")); while (fileWatcher.getStats().getCallbacksSucceeded() < callbacksSucceeded + 1) { Thread.sleep(25); } config = (OpenSslServerConfiguration) server.getSSLConfiguration(); List<SessionTicketKey> actual2 = ImmutableList.copyOf(config.ticketKeys); List<SessionTicketKey> expected2 = new TicketSeedFileParser().parse(getTicketSeedFile()); assertTicketKeysEqual(actual2, expected2); // Make sure the keys actually changed ... Assert.assertNotEquals(actual.get(0).getName(), actual2.get(0).getName()); } @Test public void testRotateSSLKeyAndCertFiles() throws InterruptedException, IOException, TException { startServer(); // This client config is using the original cert that the server starts up with SslClientConfiguration config1 = getClientSSLConfiguration(getResourceFile(Plain.class, "/rsa.crt")); // This client config is using the cert that we change to halfway through this test SslClientConfiguration config2 = getClientSSLConfiguration(getResourceFile(Plain.class, "/rsa2.crt")); scribe.Client client1 = makeNiftyClient(config1); scribe.Client client2 = makeNiftyClient(config2); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); // Before the server cert is rotated, using it on the client should fail try { client2.Log(Arrays.asList(new LogEntry("client2", "aaa"))); Assert.fail("Request with wrong certificate should have thrown an exception"); } catch (TTransportException e) { // The error is expected } // Rotate the cert and private key files long callbacksSucceeded = fileWatcher.getStats().getCallbacksSucceeded(); updateServerCertFile(getResourceFileContents(Plain.class, "/rsa2.crt")); updatePrivateKeyFile(getResourceFileContents(Plain.class, "/rsa2.key")); while (fileWatcher.getStats().getCallbacksSucceeded() < callbacksSucceeded + 1) { Thread.sleep(25); } // Need to re-create clients to get their connections to use the new server cert. client1 = makeNiftyClient(config1); client2 = makeNiftyClient(config2); // After the server cert is rotated, using the original cert on the client should fail try { client1.Log(Arrays.asList(new LogEntry("client1", "bbb"))); Assert.fail("Request with wrong certificate should have thrown an exception"); } catch (TTransportException e) { // The error is expected } Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "bbb"))), ResultCode.OK); } private static SessionTicketKey createSessionTicketKey() { SecureRandom secureRandom = new SecureRandom(); byte[] name = new byte[SessionTicketKey.NAME_SIZE]; byte[] hmac = new byte[SessionTicketKey.HMAC_KEY_SIZE]; byte[] aes = new byte[SessionTicketKey.AES_KEY_SIZE]; secureRandom.nextBytes(name); secureRandom.nextBytes(hmac); secureRandom.nextBytes(aes); return new SessionTicketKey(name, hmac, aes); } private static SSLSession getSSLSession(scribe.Client client) { TNiftyClientChannelTransport clientTransport = (TNiftyClientChannelTransport) client.getInputProtocol().getTransport(); SslHandler sslHandler = (SslHandler) clientTransport.getChannel().getNettyChannel().getPipeline().get("ssl"); return sslHandler.getEngine().getSession(); } private static boolean isSessionResumed(SSLSession sslSession) throws NoSuchFieldException, IllegalAccessException { Field sslResumedField = sslSession.getClass().getDeclaredField("isSessionResumption"); sslResumedField.setAccessible(true); return sslResumedField.getBoolean(sslSession); } }