/* * Copyright (C) 2012-2016 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.nifty.server; import com.facebook.nifty.client.FramedClientConnector; import com.facebook.nifty.client.NettyClientConfig; import com.facebook.nifty.client.NiftyClient; import com.facebook.nifty.client.TNiftyClientChannelTransport; import com.facebook.nifty.core.*; import com.facebook.nifty.ssl.JavaSslServerConfiguration; import com.facebook.nifty.ssl.SslClientConfiguration; import com.facebook.nifty.ssl.SslServerConfiguration; import com.facebook.nifty.test.LogEntry; import com.facebook.nifty.test.ResultCode; import com.facebook.nifty.test.scribe; import io.airlift.log.Logger; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; import org.apache.tomcat.jni.SessionTicketKey; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.handler.ssl.SslHandler; import org.testng.Assert; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; import java.io.File; import javax.net.ssl.SSLSession; import java.lang.reflect.Field; import java.net.InetSocketAddress; import java.security.SecureRandom; import java.util.Arrays; import java.util.List; public class TestNiftyJavaSslServer { private static final Logger log = Logger.get(TestNiftyJavaSslServer.class); private NettyServerTransport server; private int port; @BeforeMethod(alwaysRun = true) public void setup() { server = null; } @AfterMethod(alwaysRun = true) public void teardown() throws InterruptedException { if (server != null) { server.stop(); } } private void startServer() { startServer(getThriftServerDefBuilder(createSSLServerConfiguration(false, null))); } private void startServer(final ThriftServerDefBuilder thriftServerDefBuilder) { server = new NettyServerTransport(thriftServerDefBuilder.build(), NettyServerConfig.newBuilder().build(), new DefaultChannelGroup()); server.start(); port = ((InetSocketAddress)server.getServerChannel().getLocalAddress()).getPort(); } SslServerConfiguration createSSLServerConfiguration(boolean allowPlaintext, SessionTicketKey[] ticketKeys) { return JavaSslServerConfiguration.newBuilder() .certFile(new File(Plain.class.getResource("/rsa.crt").getFile())) .keyFile(new File(Plain.class.getResource("/rsa.key").getFile())) .allowPlaintext(allowPlaintext) .build(); } private ThriftServerDefBuilder getThriftServerDefBuilder(SslServerConfiguration sslServerConfiguration) { return new ThriftServerDefBuilder() .listen(0) .withSSLConfiguration(sslServerConfiguration) .withProcessor(new scribe.Processor<>(new scribe.Iface() { @Override public ResultCode Log(List<LogEntry> messages) throws TException { RequestContext context = RequestContexts.getCurrentContext(); for (LogEntry message : messages) { log.info("[Client: %s] %s: %s", context.getConnectionContext().getRemoteAddress(), message.getCategory(), message.getMessage()); } return ResultCode.OK; } })); } private static SslClientConfiguration getClientSSLConfiguration() { return new SslClientConfiguration.Builder() .caFile(new File(Plain.class.getResource("/rsa.crt").getFile())) .sessionCacheSize(10000) .sessionTimeoutSeconds(10000) .build(); } private scribe.Client makeNiftyClient(SslClientConfiguration clientSSLConfiguration) throws TTransportException, InterruptedException { NettyClientConfig config = NettyClientConfig.newBuilder() .setSSLClientConfiguration(clientSSLConfiguration).build(); InetSocketAddress address = new InetSocketAddress("localhost", port); TTransport transport = new NiftyClient(config) .connectSync(scribe.Client.class, new FramedClientConnector(address)); TProtocol protocol = new TBinaryProtocol(transport); return new scribe.Client(protocol); } private scribe.Client makeNiftyPlaintextClient() throws TTransportException, InterruptedException { NettyClientConfig config = NettyClientConfig.newBuilder().build(); InetSocketAddress address = new InetSocketAddress("localhost", port); TTransport transport = new NiftyClient(config) .connectSync(scribe.Client.class, new FramedClientConnector(address)); TProtocol protocol = new TBinaryProtocol(transport); return new scribe.Client(protocol); } @Test public void testSSL() throws InterruptedException, TException { startServer(); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "bbb"))), ResultCode.OK); scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "ccc"))), ResultCode.OK); } @Test public void testSSLWithPlaintextAllowedServer() throws InterruptedException, TException { startServer(getThriftServerDefBuilder(createSSLServerConfiguration(true, null))); scribe.Client client1 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))), ResultCode.OK); Assert.assertEquals(client1.Log(Arrays.asList(new LogEntry("client1", "bbb"))), ResultCode.OK); scribe.Client client2 = makeNiftyClient(getClientSSLConfiguration()); Assert.assertEquals(client2.Log(Arrays.asList(new LogEntry("client2", "ccc"))), ResultCode.OK); } @Test(expectedExceptions = TTransportException.class) public void testUnencryptedClient() throws InterruptedException, TException { startServer(); scribe.Client client = makeNiftyPlaintextClient(); client.Log(Arrays.asList(new LogEntry("client2", "aaa"))); client.Log(Arrays.asList(new LogEntry("client2", "bbb"))); client.Log(Arrays.asList(new LogEntry("client2", "ccc"))); } @Test public void testUnencryptedClientWithAllowPlaintextServer() throws InterruptedException, TException { startServer(getThriftServerDefBuilder(createSSLServerConfiguration(true, null))); scribe.Client client = makeNiftyPlaintextClient(); client.Log(Arrays.asList(new LogEntry("client2", "aaa"))); client.Log(Arrays.asList(new LogEntry("client2", "bbb"))); client.Log(Arrays.asList(new LogEntry("client2", "ccc"))); } @Test public void testSSLSessionResumption() throws Exception { // Ticket resumes are not supported by nifty client, so we test stateful session resumption // only. SessionTicketKey[] keys = { createSessionTicketKey() }; SslServerConfiguration sslServerConfiguration = createSSLServerConfiguration(true, keys); startServer(getThriftServerDefBuilder(sslServerConfiguration)); SslClientConfiguration sslClientConfiguration = getClientSSLConfiguration(); scribe.Client client1 = makeNiftyClient(sslClientConfiguration); client1.Log(Arrays.asList(new LogEntry("client1", "aaa"))); Assert.assertFalse(isSessionResumed(getSSLSession(client1))); scribe.Client client2 = makeNiftyClient(sslClientConfiguration); client2.Log(Arrays.asList(new LogEntry("client2", "aaa"))); Assert.assertTrue(isSessionResumed(getSSLSession(client2))); client2.Log(Arrays.asList(new LogEntry("client2", "bbb"))); Assert.assertTrue(isSessionResumed(getSSLSession(client2))); SessionTicketKey[] keys2 = { createSessionTicketKey() }; SslServerConfiguration sslServerConfiguration2 = createSSLServerConfiguration(true, keys2); server.updateSSLConfiguration(sslServerConfiguration2); scribe.Client client3 = makeNiftyClient(sslClientConfiguration); client3.Log(Arrays.asList(new LogEntry("client3", "aaa"))); Assert.assertFalse(isSessionResumed(getSSLSession(client3))); scribe.Client client4 = makeNiftyClient(sslClientConfiguration); client4.Log(Arrays.asList(new LogEntry("client4", "aaa"))); Assert.assertTrue(isSessionResumed(getSSLSession(client4))); } private static SessionTicketKey createSessionTicketKey() { SecureRandom secureRandom = new SecureRandom(); byte[] name = new byte[SessionTicketKey.NAME_SIZE]; byte[] hmac = new byte[SessionTicketKey.HMAC_KEY_SIZE]; byte[] aes = new byte[SessionTicketKey.AES_KEY_SIZE]; secureRandom.nextBytes(name); secureRandom.nextBytes(hmac); secureRandom.nextBytes(aes); return new SessionTicketKey(name, hmac, aes); } private static SSLSession getSSLSession(scribe.Client client) { TNiftyClientChannelTransport clientTransport = (TNiftyClientChannelTransport) client.getInputProtocol().getTransport(); SslHandler sslHandler = (SslHandler) clientTransport.getChannel().getNettyChannel().getPipeline().get("ssl"); return sslHandler.getEngine().getSession(); } private static boolean isSessionResumed(SSLSession sslSession) throws NoSuchFieldException, IllegalAccessException { Field sslResumedField = sslSession.getClass().getDeclaredField("isSessionResumption"); sslResumedField.setAccessible(true); return sslResumedField.getBoolean(sslSession); } }