/*
* Copyright (C) 2012-2016 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.nifty.server;
import com.facebook.nifty.client.FramedClientConnector;
import com.facebook.nifty.client.NettyClientConfig;
import com.facebook.nifty.client.NiftyClient;
import com.facebook.nifty.client.TNiftyClientChannelTransport;
import com.facebook.nifty.core.NettyServerConfig;
import com.facebook.nifty.core.NettyServerTransport;
import com.facebook.nifty.core.RequestContext;
import com.facebook.nifty.core.RequestContexts;
import com.facebook.nifty.core.ThriftServerDefBuilder;
import com.facebook.nifty.ssl.OpenSslServerConfiguration;
import com.facebook.nifty.ssl.PollingMultiFileWatcher;
import com.facebook.nifty.ssl.SslClientConfiguration;
import com.facebook.nifty.ssl.SslConfigFileWatcher;
import com.facebook.nifty.ssl.SslServerConfiguration;
import com.facebook.nifty.ssl.TicketSeedFileParser;
import com.facebook.nifty.ssl.TransportAttachObserver;
import com.facebook.nifty.test.LogEntry;
import com.facebook.nifty.test.ResultCode;
import com.facebook.nifty.test.scribe;
import com.google.common.collect.ImmutableList;
import com.google.common.io.Files;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.tomcat.jni.SessionTicketKey;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.handler.ssl.HackyJdkSslClientContext;
import org.jboss.netty.handler.ssl.SslContext;
import org.jboss.netty.handler.ssl.SslHandler;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import static java.util.Objects.requireNonNull;
/**
 * Tests for {@link NettyServerTransport} running with an {@link OpenSslServerConfiguration}.
 * <p>
 * Each test starts a server on an ephemeral port and talks to the "scribe" Thrift service. Clients are either Nifty
 * clients (SSL or plaintext) or a raw JDK {@link SSLSocket}. The ticket seed, private key and certificate files are
 * copied from test resources into temp files. This lets tests rewrite them and wait for
 * {@link PollingMultiFileWatcher} callbacks driven by {@link SslConfigFileWatcher}.
 */
public class TestNiftyOpenSslServer
{
    private static final Logger log = Logger.get(TestNiftyOpenSslServer.class);

    // Password provided to the openssl command line tool when creating the client.pkcs12 file
    private static final String CLIENT_PKCS12_PASSWORD = "12345";

    // Upper bound on how long a test waits for the file watcher to report a successful callback, so that a
    // watcher that never fires fails the test instead of hanging it forever.
    private static final long FILE_WATCHER_CALLBACK_TIMEOUT_NANOS = TimeUnit.SECONDS.toNanos(30);

    // Interval between polls of the file watcher's stats while waiting for a callback.
    private static final long FILE_WATCHER_POLL_INTERVAL_MILLIS = 25;

    // Shared source of randomness for generated session ticket keys.
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    private NettyServerTransport server;
    private int port;
    private PollingMultiFileWatcher fileWatcher = null;

    // Server-side configs
    private File ticketSeedFile = null;
    private File privateKeyFile = null;
    private File serverCertFile = null;

    // Client-side configs
    private File clientCertFile = null;
    private File clientPKCS12File = null;

    @BeforeMethod(alwaysRun = true)
    public void setup()
    {
        server = null;
        fileWatcher = new PollingMultiFileWatcher(Duration.valueOf("0 ms"), Duration.valueOf("100 ms"));
    }

    /**
     * Stops the server (if one was started) and deletes every temp file created by the test. File cleanup runs
     * in a {@code finally} block so it happens even if stopping the server throws.
     */
    @AfterMethod(alwaysRun = true)
    public void teardown()
            throws InterruptedException
    {
        try {
            if (server != null) {
                server.stop();
            }
        }
        finally {
            server = null;
            fileWatcher = null;
            deleteFilesIfExistIgnoreErrors(
                    ticketSeedFile,
                    privateKeyFile,
                    serverCertFile,
                    clientCertFile,
                    clientPKCS12File);
            ticketSeedFile = privateKeyFile = serverCertFile = clientCertFile = clientPKCS12File = null;
        }
    }

    /** Starts an SSL-only server whose key, cert and ticket seeds are reloaded via the file watcher. */
    private void startServer()
    {
        startServer(false);
    }

    /**
     * Starts a server whose SSL configuration is backed by the temp ticket seed, key and cert files. It waits for
     * the file watcher's first successful callback before returning.
     *
     * @param allowPlaintext whether the server should also accept unencrypted connections.
     */
    private void startServer(boolean allowPlaintext)
    {
        try {
            List<SessionTicketKey> ticketKeysList = new TicketSeedFileParser().parse(getTicketSeedFile());
            SessionTicketKey[] ticketKeys = ticketKeysList.toArray(new SessionTicketKey[0]);
            SslConfigFileWatcher configUpdater = new SslConfigFileWatcher(
                    getTicketSeedFile(),
                    getPrivateKeyFile(),
                    getServerCertFile(),
                    null,
                    fileWatcher);
            SslServerConfiguration config = createSSLServerConfiguration(allowPlaintext, ticketKeys);
            long callbacksSucceeded = fileWatcher.getStats().getCallbacksSucceeded();
            startServer(getThriftServerDefBuilder(config, configUpdater));
            // Wait for the first file watcher callback to be processed before the test uses the server.
            awaitFileWatcherCallback(callbacksSucceeded);
        }
        catch (InterruptedException e) {
            // Preserve the interrupt status for whoever runs this test thread.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds and starts a server from the given definition, and records the port it actually bound to.
     *
     * @param thriftServerDefBuilder the server definition; tests use {@code listen(0)} for an ephemeral port.
     */
    private void startServer(final ThriftServerDefBuilder thriftServerDefBuilder)
    {
        server = new NettyServerTransport(thriftServerDefBuilder.build(),
                NettyServerConfig.newBuilder().build(),
                new DefaultChannelGroup());
        server.start();
        port = ((InetSocketAddress) server.getServerChannel().getLocalAddress()).getPort();
    }

    /**
     * Polls the file watcher stats until at least one more callback has succeeded than {@code baselineCallbacks}.
     *
     * @param baselineCallbacks the succeeded-callback count observed before the triggering action.
     * @throws InterruptedException if interrupted while sleeping between polls.
     * @throws AssertionError if no new callback succeeds within {@link #FILE_WATCHER_CALLBACK_TIMEOUT_NANOS}.
     */
    private void awaitFileWatcherCallback(long baselineCallbacks)
            throws InterruptedException
    {
        long deadline = System.nanoTime() + FILE_WATCHER_CALLBACK_TIMEOUT_NANOS;
        while (fileWatcher.getStats().getCallbacksSucceeded() < baselineCallbacks + 1) {
            // Compare via subtraction so the check stays correct if nanoTime() wraps around.
            if (System.nanoTime() - deadline > 0) {
                Assert.fail("Timed out waiting for a successful file watcher callback");
            }
            Thread.sleep(FILE_WATCHER_POLL_INTERVAL_MILLIS);
        }
    }

    /**
     * Creates an OpenSSL server configuration using the temp server cert and private key files.
     *
     * @param allowPlaintext whether unencrypted connections are also accepted.
     * @param ticketKeys session ticket keys to configure; tests also pass {@code null}.
     * @throws IOException if the temp key/cert files cannot be created.
     */
    SslServerConfiguration createSSLServerConfiguration(boolean allowPlaintext,
            SessionTicketKey[] ticketKeys) throws IOException
    {
        return OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(allowPlaintext)
                .ticketKeys(ticketKeys)
                .build();
    }

    /** Same as the three-argument overload, with a handler that always returns {@link ResultCode#OK}. */
    private ThriftServerDefBuilder getThriftServerDefBuilder(
            SslServerConfiguration sslServerConfiguration,
            TransportAttachObserver configUpdater)
    {
        return getThriftServerDefBuilder(sslServerConfiguration, configUpdater, (List<LogEntry> entries) -> ResultCode.OK);
    }

    /**
     * Creates a server definition listening on an ephemeral port and serving the scribe service.
     *
     * @param sslServerConfiguration the SSL configuration for the server.
     * @param configUpdater observer attached to the transport; may be {@code null}.
     * @param thriftHandler computes the response for each {@code Log} call. Any exception it throws is wrapped in
     *                      a {@link TException}.
     */
    private ThriftServerDefBuilder getThriftServerDefBuilder(
            SslServerConfiguration sslServerConfiguration,
            TransportAttachObserver configUpdater,
            final Function<List<LogEntry>, ResultCode> thriftHandler)
    {
        requireNonNull(thriftHandler);
        return new ThriftServerDefBuilder()
                .listen(0)
                .withSSLConfiguration(sslServerConfiguration)
                .withTransportAttachObserver(configUpdater)
                .withProcessor(new scribe.Processor<>(new scribe.Iface() {
                    @Override
                    public ResultCode Log(List<LogEntry> messages) throws TException
                    {
                        RequestContext context = RequestContexts.getCurrentContext();
                        for (LogEntry message : messages) {
                            log.info("[Client: %s] %s: %s",
                                    context.getConnectionContext().getRemoteAddress(),
                                    message.getCategory(),
                                    message.getMessage());
                        }
                        try {
                            return thriftHandler.apply(messages);
                        }
                        catch (Exception e) {
                            throw new TException(e);
                        }
                    }
                }));
    }

    /** Client SSL config that trusts the temp server cert and presents no client certificate. */
    private SslClientConfiguration getClientSSLConfiguration() throws IOException
    {
        return getClientSSLConfiguration(null);
    }

    /** Client SSL config that trusts {@code certFile} (or the temp server cert if {@code null}); no client cert. */
    private SslClientConfiguration getClientSSLConfiguration(File certFile) throws IOException
    {
        return getClientSSLConfiguration(certFile, null);
    }

    /**
     * Builds a client SSL configuration.
     *
     * @param certFile the certificate file passed to the client context; if {@code null}, the temp server cert
     *                 file is used.
     * @param keyManagers client key managers (for client authentication); may be {@code null}.
     * @throws IOException if the temp server cert file cannot be created.
     */
    private SslClientConfiguration getClientSSLConfiguration(File certFile, KeyManager[] keyManagers) throws IOException
    {
        SslContext context = new HackyJdkSslClientContext(
                null,
                certFile == null ? getServerCertFile() : certFile,
                keyManagers,
                null,
                null,
                null,
                10000,
                10000
        );
        return new SslClientConfiguration.Builder().sslContext(context).build();
    }

    /** Connects an SSL Nifty scribe client to the local server using the given client SSL configuration. */
    private scribe.Client makeNiftyClient(SslClientConfiguration clientSSLConfiguration)
            throws TTransportException, InterruptedException
    {
        NettyClientConfig config = NettyClientConfig.newBuilder()
                .setSSLClientConfiguration(clientSSLConfiguration)
                .build();
        return connectScribeClient(config);
    }

    /** Connects a plaintext (no SSL) Nifty scribe client to the local server. */
    private scribe.Client makeNiftyPlaintextClient()
            throws TTransportException, InterruptedException
    {
        return connectScribeClient(NettyClientConfig.newBuilder().build());
    }

    /**
     * Opens a framed, binary-protocol scribe client to {@code localhost} on the server's port.
     *
     * @param config the Nifty client configuration (decides whether SSL is used).
     */
    private scribe.Client connectScribeClient(NettyClientConfig config)
            throws TTransportException, InterruptedException
    {
        InetSocketAddress address = new InetSocketAddress("localhost", port);
        TTransport transport = new NiftyClient(config)
                .connectSync(scribe.Client.class, new FramedClientConnector(address));
        TProtocol protocol = new TBinaryProtocol(transport);
        return new scribe.Client(protocol);
    }

    /**
     * Returns a file path to the given resource loaded using the given class's class loader.
     *
     * @param clazz the class whose class loader should be used to load the resource.
     * @param resourcePath the resource path.
     * @return a File object representing the path to the resource.
     * @throws IllegalArgumentException if the resource does not exist or its URL cannot be converted to a file.
     */
    private File getResourceFile(Class<?> clazz, String resourcePath)
    {
        java.net.URL resource = clazz.getResource(resourcePath);
        if (resource == null) {
            throw new IllegalArgumentException("Resource not found: " + resourcePath);
        }
        try {
            // URL.getFile() returns the percent-encoded path (e.g. a space becomes "%20"). Going through a URI
            // decodes it, so the File points at the real location on disk.
            return new File(resource.toURI());
        }
        catch (java.net.URISyntaxException e) {
            throw new IllegalArgumentException("Invalid resource URL: " + resource, e);
        }
    }

    /**
     * Returns the contents of the given resource loaded using the given class's class loader.
     *
     * @param clazz the class whose class loader should be used to load the resource.
     * @param resourcePath the resource path.
     * @return the contents of the resource file.
     * @throws IOException if the resource file could not be read.
     */
    private byte[] getResourceFileContents(Class<?> clazz, String resourcePath) throws IOException
    {
        return Files.toByteArray(getResourceFile(clazz, resourcePath));
    }

    /**
     * Overwrites the contents of the given file with the given byte array. If the file does not exist, it will
     * be created.
     *
     * @param file the file to overwrite.
     * @param newContents new file contents.
     * @throws IOException if the write fails.
     */
    private void overwriteFile(File file, byte[] newContents) throws IOException
    {
        java.nio.file.Files.write(file.toPath(), newContents);
    }

    /**
     * Best-effort attempt to delete all of the given files if they exist. Ignores errors.
     *
     * @param files the files to delete; {@code null} entries are skipped.
     */
    private void deleteFilesIfExistIgnoreErrors(File... files)
    {
        for (File file : files) {
            if (file == null) {
                continue;
            }
            try {
                java.nio.file.Files.deleteIfExists(file.toPath());
            }
            catch (IOException ignored) {
                // Cleanup is best-effort: a leftover temp file must not fail the test.
            }
        }
    }

    /**
     * Creates a temp file with the same contents as the given resource. Returns the path to the temp file.
     * The temp file should be deleted by the user when the test finishes.
     *
     * @param clazz the class whose class loader should be used to load the resource.
     * @param resourcePath the resource path.
     * @return a File object representing the path to the new temp file.
     * @throws IOException if the resource file could not be read, or temp file could not be created or written.
     */
    private File initTempFileFromResource(Class<?> clazz, String resourcePath) throws IOException
    {
        File result = File.createTempFile("test_nifty_openssl_server", resourcePath.replace('/', '_'));
        overwriteFile(result, getResourceFileContents(clazz, resourcePath));
        return result;
    }

    /**
     * Returns the path to a temporary ticket seed file. If the temp file does not yet exist, it is created on
     * demand and initialized with the contents of the "/ticket_seeds.json" resource.
     * The temp file should be deleted by the user when the test finishes.
     *
     * @return the new file.
     * @throws IOException if reading the resource or creating the temp file fails.
     */
    private File getTicketSeedFile() throws IOException
    {
        if (ticketSeedFile == null) {
            ticketSeedFile = initTempFileFromResource(Plain.class, "/ticket_seeds.json");
        }
        return ticketSeedFile;
    }

    /**
     * Overwrites the contents of the ticket seed file with the given byte array.
     *
     * @param newContents new ticket seed file contents.
     * @throws IOException if writing the file fails.
     */
    private void updateTicketSeedFile(byte[] newContents) throws IOException
    {
        overwriteFile(getTicketSeedFile(), newContents);
    }

    /**
     * Returns the path to a temporary private key file. If the temp file does not yet exist, it is created on
     * demand and initialized with the contents of the "/rsa.key" resource.
     * The temp file should be deleted by the user when the test finishes.
     *
     * @return the new file.
     * @throws IOException if reading the resource or creating the temp file fails.
     */
    private File getPrivateKeyFile() throws IOException
    {
        if (privateKeyFile == null) {
            privateKeyFile = initTempFileFromResource(Plain.class, "/rsa.key");
        }
        return privateKeyFile;
    }

    /**
     * Overwrites the contents of the private key file with the given byte array.
     *
     * @param newContents new private key file contents.
     * @throws IOException if writing the file fails.
     */
    private void updatePrivateKeyFile(byte[] newContents) throws IOException
    {
        overwriteFile(getPrivateKeyFile(), newContents);
    }

    /**
     * Returns the path to a temporary server certificate file. If the temp file does not yet exist,
     * it is created on demand and initialized with the contents of the "/rsa.crt" resource.
     * The temp file should be deleted by the user when the test finishes.
     *
     * @return the new file.
     * @throws IOException if reading the resource or creating the temp file fails.
     */
    private File getServerCertFile() throws IOException
    {
        if (serverCertFile == null) {
            serverCertFile = initTempFileFromResource(Plain.class, "/rsa.crt");
        }
        return serverCertFile;
    }

    /**
     * Overwrites the contents of the server certificate file with the given byte array.
     *
     * @param newContents new certificate file contents.
     * @throws IOException if writing the file fails.
     */
    private void updateServerCertFile(byte[] newContents) throws IOException
    {
        overwriteFile(getServerCertFile(), newContents);
    }

    /**
     * Returns the path to a temporary client certificate file. If the temp file does not yet exist,
     * it is created on demand and initialized with the contents of the "/client.crt" resource.
     * The temp file should be deleted by the user when the test finishes.
     *
     * @return the new file.
     * @throws IOException if reading the resource or creating the temp file fails.
     */
    private File getClientCertFile() throws IOException
    {
        if (clientCertFile == null) {
            clientCertFile = initTempFileFromResource(Plain.class, "/client.crt");
        }
        return clientCertFile;
    }

    /**
     * Overwrites the contents of the client certificate file with the given byte array.
     *
     * @param newContents new client certificate file contents.
     * @throws IOException if writing the file fails.
     */
    private void updateClientCertFile(byte[] newContents) throws IOException
    {
        overwriteFile(getClientCertFile(), newContents);
    }

    /**
     * Returns the path to a temporary client PKCS12 key file. If the temp file does not yet exist,
     * it is created on demand and initialized with the contents of the "/client.pkcs12" resource.
     * The temp file should be deleted by the user when the test finishes.
     *
     * @return the new file.
     * @throws IOException if reading the resource or creating the temp file fails.
     */
    private File getClientPKCS12File() throws IOException
    {
        if (clientPKCS12File == null) {
            clientPKCS12File = initTempFileFromResource(Plain.class, "/client.pkcs12");
        }
        return clientPKCS12File;
    }

    /**
     * Overwrites the contents of the client PKCS12 key file with the given byte array.
     *
     * @param newContents new PKCS12 file contents.
     * @throws IOException if writing the file fails.
     */
    private void updateClientPKCS12File(byte[] newContents) throws IOException
    {
        overwriteFile(getClientPKCS12File(), newContents);
    }

    /**
     * Asserts that the given lists of session ticket keys are the same. {@link SessionTicketKey} seems to not
     * implement a proper equals() method so we have to do this the hard way.
     *
     * @param actualKeys the actual ticket keys.
     * @param expectedKeys the expected ticket keys.
     */
    private void assertTicketKeysEqual(List<SessionTicketKey> actualKeys, List<SessionTicketKey> expectedKeys)
    {
        Assert.assertEquals(actualKeys.size(), expectedKeys.size());
        for (int i = 0; i < actualKeys.size(); ++i) {
            SessionTicketKey actualKey = actualKeys.get(i);
            SessionTicketKey expectedKey = expectedKeys.get(i);
            Assert.assertEquals(actualKey.getAesKey(), expectedKey.getAesKey());
            Assert.assertEquals(actualKey.getHmacKey(), expectedKey.getHmacKey());
            Assert.assertEquals(actualKey.getName(), expectedKey.getName());
        }
    }

    @Test
    public void testSSL() throws InterruptedException, TException, IOException
    {
        startServer();
        scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration());
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK);
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "bbb"))), ResultCode.OK);
        scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration());
        Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "ccc"))), ResultCode.OK);
    }

    @Test
    public void testSSLWithPlaintextAllowedServer() throws InterruptedException, TException, IOException
    {
        startServer(true);
        scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration());
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK);
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "bbb"))), ResultCode.OK);
        scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration());
        Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "ccc"))), ResultCode.OK);
    }

    @Test(expectedExceptions = TTransportException.class)
    public void testUnencryptedClient() throws InterruptedException, TException
    {
        startServer();
        scribe.Client client = makeNiftyPlaintextClient();
        client.Log(Arrays.asList(new LogEntry("client2", "aaa")));
        client.Log(Arrays.asList(new LogEntry("client2", "bbb")));
        client.Log(Arrays.asList(new LogEntry("client2", "ccc")));
    }

    @Test
    public void testUnencryptedClientWithAllowPlaintextServer() throws InterruptedException, TException, IOException
    {
        startServer(true);
        scribe.Client client = makeNiftyPlaintextClient();
        Assert.assertEquals(client.Log(Arrays.asList(new LogEntry("client2", "aaa"))), ResultCode.OK);
        Assert.assertEquals(client.Log(Arrays.asList(new LogEntry("client2", "bbb"))), ResultCode.OK);
        Assert.assertEquals(client.Log(Arrays.asList(new LogEntry("client2", "ccc"))), ResultCode.OK);
    }

    /**
     * Loads the client's key material from the temp PKCS12 file.
     *
     * @return key managers initialized with the client key store.
     * @throws SSLException wrapping any failure to read or load the key store.
     */
    private KeyManager[] getClientKeyManagers() throws SSLException
    {
        try {
            KeyStore keyStore = KeyStore.getInstance("PKCS12");
            try (InputStream keyInput = new FileInputStream(getClientPKCS12File())) {
                keyStore.load(keyInput, CLIENT_PKCS12_PASSWORD.toCharArray());
            }
            KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(
                    KeyManagerFactory.getDefaultAlgorithm());
            keyManagerFactory.init(keyStore, CLIENT_PKCS12_PASSWORD.toCharArray());
            return keyManagerFactory.getKeyManagers();
        }
        catch (Exception e) {
            throw new SSLException(e);
        }
    }

    /**
     * Opens a plain TCP connection to the server and optionally sleeps. It then layers a JDK TLS socket (trusting
     * the temp server cert and presenting the client PKCS12 key) on top, performs the handshake and asserts that
     * the session is valid.
     *
     * @param delay milliseconds to wait between the TCP connect and the TLS handshake; 0 means no wait.
     * @throws SSLException wrapping ANY failure, including assertion failures, so idle-timeout tests can expect it.
     */
    private void startRawSSLClient(long delay) throws SSLException
    {
        try {
            KeyStore keyStore = KeyStore.getInstance("JKS");
            keyStore.load(null, null);
            CertificateFactory cf = CertificateFactory.getInstance("X.509");
            X509Certificate cert;
            try (InputStream certInput = new FileInputStream(getServerCertFile())) {
                cert = (X509Certificate) cf.generateCertificate(certInput);
            }
            X500Principal principal = cert.getSubjectX500Principal();
            keyStore.setCertificateEntry(principal.getName("RFC2253"), cert);
            TrustManagerFactory trustManagerFactory =
                    TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
            trustManagerFactory.init(keyStore);
            KeyManager[] clientKeyManagers = getClientKeyManagers();
            SSLContext context = SSLContext.getInstance("TLS");
            context.init(clientKeyManagers, trustManagerFactory.getTrustManagers(), null);
            try (Socket sock = new Socket()) {
                sock.connect(new InetSocketAddress("localhost", port));
                if (delay != 0) {
                    Thread.sleep(delay);
                }
                // autoClose=true: closing the SSL socket also closes the underlying plain socket.
                try (SSLSocket sslSocket =
                        (SSLSocket) context.getSocketFactory().createSocket(sock, "localhost", port, true)) {
                    sslSocket.startHandshake();
                    SSLSession session = sslSocket.getSession();
                    Assert.assertTrue(session.isValid());
                }
            }
        }
        catch (Throwable t) {
            if (t instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new SSLException(t);
        }
    }

    @Test
    public void testDefaultServerWithClientCert() throws InterruptedException, IOException, TException
    {
        SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(false)
                .clientCAFile(getClientCertFile())
                .build();
        ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null);
        startServer(builder);
        scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers()));
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK);
    }

    @Test
    public void testOptionalClientAuthenticatingServer() throws InterruptedException, IOException, TException
    {
        SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(false)
                .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_OPTIONAL)
                .clientCAFile(getClientCertFile())
                .build();
        ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null);
        startServer(builder);
        scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers()));
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK);
        scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration());
        Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "aaa"))), ResultCode.OK);
    }

    @Test
    public void testClientAuthenticatingServer() throws InterruptedException, IOException, TException
    {
        SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(false)
                .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE)
                .clientCAFile(getClientCertFile())
                .build();
        ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null);
        startServer(builder);
        scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers()));
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK);
    }

    @Test
    public void testClientAuthenticatingServerAllowPlaintext() throws InterruptedException, IOException, TException
    {
        SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(true)
                .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE)
                .clientCAFile(getClientCertFile())
                .build();
        ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null);
        startServer(builder);
        scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers()));
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK);
        scribe.Client client2 = makeNiftyPlaintextClient();
        Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "aaa"))), ResultCode.OK);
    }

    @Test
    public void testThreadLocalSslBufferPool() throws InterruptedException, IOException, TException
    {
        SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(false)
                .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE)
                .clientCAFile(getClientCertFile())
                .threadLocalSslBuffer(true)
                .build();
        ThriftServerDefBuilder builder = getThriftServerDefBuilder(serverConfig, null);
        startServer(builder);
        scribe.Client client = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers()));
        Assert.assertEquals(client.Log(Arrays.asList(new LogEntry("client", "aaa"))), ResultCode.OK);
    }

    @Test(expectedExceptions = TTransportException.class)
    public void testClientWithoutCerts() throws InterruptedException, IOException, TException
    {
        SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(false)
                .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE)
                .clientCAFile(getClientCertFile())
                .build();
        startServer(getThriftServerDefBuilder(serverConfig, null));
        scribe.Client client = makeNiftyClient(getClientSSLConfiguration());
        client.Log(Arrays.asList(new LogEntry("client", "aaa")));
    }

    @Test(expectedExceptions = SSLException.class)
    public void testWithServerIdleTimeout()
            throws TException, InterruptedException, IOException, NoSuchAlgorithmException
    {
        startServer(getThriftServerDefBuilder(createSSLServerConfiguration(false, null), null)
                .clientIdleTimeout(Duration.succinctDuration(1, TimeUnit.MILLISECONDS)));
        startRawSSLClient(200);
    }

    @Test(expectedExceptions = SSLException.class)
    public void testWithServerIdleTimeoutAllowPlaintext()
            throws TException, InterruptedException, IOException, NoSuchAlgorithmException
    {
        startServer(getThriftServerDefBuilder(createSSLServerConfiguration(true, null), null)
                .clientIdleTimeout(Duration.succinctDuration(1, TimeUnit.MILLISECONDS)));
        startRawSSLClient(200);
    }

    @Test(expectedExceptions = TApplicationException.class,
            expectedExceptionsMessageRegExp = "Internal error processing Log")
    public void testPlaintextServerThrowsException() throws InterruptedException, IOException, TException
    {
        startServer(getThriftServerDefBuilder(
                createSSLServerConfiguration(true /* allowPlaintext */, null),
                null,
                (List<LogEntry> messages) -> { throw new RuntimeException("Error"); }));
        scribe.Client client = makeNiftyPlaintextClient();
        client.Log(Arrays.asList(new LogEntry("client", "aaa")));
    }

    @Test(expectedExceptions = TApplicationException.class,
            expectedExceptionsMessageRegExp = "Internal error processing Log")
    public void testDefaultServerThrowsException() throws InterruptedException, IOException, TException
    {
        startServer(getThriftServerDefBuilder(
                createSSLServerConfiguration(false, null),
                null,
                (List<LogEntry> messages) -> { throw new RuntimeException("Error"); }));
        scribe.Client client = makeNiftyClient(getClientSSLConfiguration());
        client.Log(Arrays.asList(new LogEntry("client", "aaa")));
    }

    @Test(expectedExceptions = TApplicationException.class,
            expectedExceptionsMessageRegExp = "Internal error processing Log")
    public void testClientAuthenticatingServerThrowsException() throws InterruptedException, IOException, TException
    {
        SslServerConfiguration serverConfig = OpenSslServerConfiguration.newBuilder()
                .certFile(getServerCertFile())
                .keyFile(getPrivateKeyFile())
                .allowPlaintext(false)
                .sslVerification(OpenSslServerConfiguration.SSLVerification.VERIFY_REQUIRE)
                .clientCAFile(getClientCertFile())
                .build();
        startServer(getThriftServerDefBuilder(
                serverConfig,
                null,
                (List<LogEntry> messages) -> { throw new RuntimeException("Error"); }));
        scribe.Client client = makeNiftyClient(getClientSSLConfiguration(null, getClientKeyManagers()));
        client.Log(Arrays.asList(new LogEntry("client", "aaa")));
    }

    @Test
    public void testSSLSessionResumption() throws Exception
    {
        // Ticket resumes are not supported by nifty client, so we test stateful session resumption
        // only.
        SessionTicketKey[] keys = {createSessionTicketKey()};
        SslServerConfiguration sslServerConfiguration = createSSLServerConfiguration(true, keys);
        startServer(getThriftServerDefBuilder(sslServerConfiguration, null));
        SslClientConfiguration sslClientConfiguration = getClientSSLConfiguration();

        scribe.Client client1 = makeNiftyClient(sslClientConfiguration);
        client1.Log(Arrays.asList(new LogEntry("client1", "aaa")));
        Assert.assertFalse(isSessionResumed(getSSLSession(client1)));

        scribe.Client client2 = makeNiftyClient(sslClientConfiguration);
        client2.Log(Arrays.asList(new LogEntry("client2", "aaa")));
        Assert.assertTrue(isSessionResumed(getSSLSession(client2)));
        client2.Log(Arrays.asList(new LogEntry("client2", "bbb")));
        Assert.assertTrue(isSessionResumed(getSSLSession(client2)));

        // After the server SSL configuration is replaced, the next client is expected not to resume its session.
        SessionTicketKey[] keys2 = {createSessionTicketKey()};
        SslServerConfiguration sslServerConfiguration2 = createSSLServerConfiguration(true, keys2);
        server.updateSSLConfiguration(sslServerConfiguration2);

        scribe.Client client3 = makeNiftyClient(sslClientConfiguration);
        client3.Log(Arrays.asList(new LogEntry("client3", "aaa")));
        Assert.assertFalse(isSessionResumed(getSSLSession(client3)));

        scribe.Client client4 = makeNiftyClient(sslClientConfiguration);
        client4.Log(Arrays.asList(new LogEntry("client4", "aaa")));
        Assert.assertTrue(isSessionResumed(getSSLSession(client4)));
    }

    /**
     * Observer that records the transport it is attached to and lets the test push a new SSL configuration to it.
     * It is static because it does not use the enclosing test instance.
     */
    private static final class TestConfigUpdater implements TransportAttachObserver
    {
        public NettyServerTransport attachedTransport;

        @Override
        public void attachTransport(NettyServerTransport transport)
        {
            attachedTransport = transport;
        }

        @Override
        public void detachTransport()
        {
            attachedTransport = null;
        }

        void updateSSLConfig(SslServerConfiguration newConfig)
        {
            attachedTransport.updateSSLConfiguration(newConfig);
        }
    }

    @Test
    public void testAttachTransportToUpdater() throws InterruptedException, IOException
    {
        TestConfigUpdater configUpdater = new TestConfigUpdater();
        SessionTicketKey[] keys = {createSessionTicketKey()};
        SslServerConfiguration sslServerConfiguration = createSSLServerConfiguration(true, keys);
        startServer(getThriftServerDefBuilder(sslServerConfiguration, configUpdater));
        Assert.assertNotNull(configUpdater.attachedTransport);
        SessionTicketKey[] newKeys = {createSessionTicketKey()};
        SslServerConfiguration newConfig = createSSLServerConfiguration(true, newKeys);
        configUpdater.updateSSLConfig(newConfig);
        server.stop();
        server = null;
        Assert.assertNull(configUpdater.attachedTransport);
    }

    @Test
    public void testRotateTicketSeedFile() throws InterruptedException, IOException
    {
        startServer();
        OpenSslServerConfiguration config = (OpenSslServerConfiguration) server.getSSLConfiguration();
        List<SessionTicketKey> actual = ImmutableList.copyOf(config.ticketKeys);
        List<SessionTicketKey> expected = new TicketSeedFileParser().parse(getTicketSeedFile());
        assertTicketKeysEqual(actual, expected);

        // Rotate the ticket seeds file
        long callbacksSucceeded = fileWatcher.getStats().getCallbacksSucceeded();
        updateTicketSeedFile(getResourceFileContents(Plain.class, "/ticket_seeds2.json"));
        awaitFileWatcherCallback(callbacksSucceeded);

        config = (OpenSslServerConfiguration) server.getSSLConfiguration();
        List<SessionTicketKey> actual2 = ImmutableList.copyOf(config.ticketKeys);
        List<SessionTicketKey> expected2 = new TicketSeedFileParser().parse(getTicketSeedFile());
        assertTicketKeysEqual(actual2, expected2);
        // Make sure the keys actually changed ...
        Assert.assertNotEquals(actual.get(0).getName(), actual2.get(0).getName());
    }

    @Test
    public void testRotateSSLKeyAndCertFiles() throws InterruptedException, IOException, TException
    {
        startServer();
        // This client config is using the original cert that the server starts up with
        SslClientConfiguration config1 = getClientSSLConfiguration(getResourceFile(Plain.class, "/rsa.crt"));
        // This client config is using the cert that we change to halfway through this test
        SslClientConfiguration config2 = getClientSSLConfiguration(getResourceFile(Plain.class, "/rsa2.crt"));
        scribe.Client client1 = makeNiftyClient(config1);
        scribe.Client client2 = makeNiftyClient(config2);
        Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK);
        // Before the server cert is rotated, using it on the client should fail
        try {
            client2.Log(Arrays.asList(new LogEntry("client2", "aaa")));
            Assert.fail("Request with wrong certificate should have thrown an exception");
        }
        catch (TTransportException e) {
            // The error is expected
        }

        // Rotate the cert and private key files
        long callbacksSucceeded = fileWatcher.getStats().getCallbacksSucceeded();
        updateServerCertFile(getResourceFileContents(Plain.class, "/rsa2.crt"));
        updatePrivateKeyFile(getResourceFileContents(Plain.class, "/rsa2.key"));
        awaitFileWatcherCallback(callbacksSucceeded);

        // Need to re-create clients to get their connections to use the new server cert.
        client1 = makeNiftyClient(config1);
        client2 = makeNiftyClient(config2);
        // After the server cert is rotated, using the original cert on the client should fail
        try {
            client1.Log(Arrays.asList(new LogEntry("client1", "bbb")));
            Assert.fail("Request with wrong certificate should have thrown an exception");
        }
        catch (TTransportException e) {
            // The error is expected
        }
        Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "bbb"))), ResultCode.OK);
    }

    /** Creates a session ticket key whose name, HMAC key and AES key are filled with secure random bytes. */
    private static SessionTicketKey createSessionTicketKey()
    {
        byte[] name = new byte[SessionTicketKey.NAME_SIZE];
        byte[] hmac = new byte[SessionTicketKey.HMAC_KEY_SIZE];
        byte[] aes = new byte[SessionTicketKey.AES_KEY_SIZE];
        SECURE_RANDOM.nextBytes(name);
        SECURE_RANDOM.nextBytes(hmac);
        SECURE_RANDOM.nextBytes(aes);
        return new SessionTicketKey(name, hmac, aes);
    }

    /** Returns the SSL session of the client's connection, read from the "ssl" handler in its Netty pipeline. */
    private static SSLSession getSSLSession(scribe.Client client)
    {
        TNiftyClientChannelTransport clientTransport =
                (TNiftyClientChannelTransport) client.getInputProtocol().getTransport();
        SslHandler sslHandler = (SslHandler) clientTransport.getChannel().getNettyChannel().getPipeline().get("ssl");
        return sslHandler.getEngine().getSession();
    }

    /**
     * Reads the private {@code isSessionResumption} field of the session object via reflection.
     * NOTE(review): this depends on a non-public field of the JDK session implementation class and may break on
     * JDK versions that rename or remove it.
     */
    private static boolean isSessionResumed(SSLSession sslSession) throws NoSuchFieldException, IllegalAccessException
    {
        Field sslResumedField = sslSession.getClass().getDeclaredField("isSessionResumption");
        sslResumedField.setAccessible(true);
        return sslResumedField.getBoolean(sslSession);
    }
}