// // ======================================================================== // Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd. // ------------------------------------------------------------------------ // All rights reserved. This program and the accompanying materials // are made available under the terms of the Eclipse Public License v1.0 // and Apache License v2.0 which accompanies this distribution. // // The Eclipse Public License is available at // http://www.eclipse.org/legal/epl-v10.html // // The Apache License v2.0 is available at // http://www.opensource.org/licenses/apache2.0.php // // You may elect to redistribute this code under either of these licenses. // ======================================================================== // package org.mortbay.jetty.alpn; import java.util.Arrays; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLSession; import org.eclipse.jetty.alpn.ALPN; import org.junit.Assert; import org.junit.Before; import org.junit.Test; public abstract class AbstractALPNTest<T> { protected abstract SSLResult<T> performTLSHandshake(SSLResult<T> handshake, ALPN.ClientProvider clientProvider, ALPN.ServerProvider serverProvider) throws Exception; protected abstract void performTLSClose(SSLResult<T> sslResult) throws Exception; protected abstract void performDataExchange(SSLResult<T> sslResult) throws Exception; protected abstract void performTLSRenegotiation(SSLResult<T> sslResult, boolean client) throws Exception; protected abstract SSLSession getSSLSession(SSLResult<T> sslResult, boolean client) throws Exception; @Before public void prepare() throws Exception { Assert.assertNull("ALPN classes must be in the bootclasspath.", ALPN.class.getClassLoader()); ALPN.debug = true; } @Test public void testALPNSuccessful() throws Exception { final String protocolName = "test"; final CountDownLatch latch = new CountDownLatch(3); ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() { @Override public List<String> protocols() { latch.countDown(); return Arrays.asList(protocolName); } @Override public void unsupported() { Assert.fail(); } @Override public void selected(String protocol) { Assert.assertEquals(protocolName, protocol); latch.countDown(); } }; ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() { @Override public void unsupported() { Assert.fail(); } @Override public String select(List<String> protocols) { Assert.assertEquals(1, protocols.size()); String protocol = protocols.get(0); Assert.assertEquals(protocolName, protocol); latch.countDown(); return protocol; } }; SSLResult<T> sslResult = performTLSHandshake(null, clientProvider, serverProvider); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); // Verify that we can exchange data without errors. performDataExchange(sslResult); performTLSClose(sslResult); } @Test public void testServerDoesNotSendALPN() throws Exception { final String protocolName = "test"; final CountDownLatch latch = new CountDownLatch(3); ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() { @Override public List<String> protocols() { latch.countDown(); return Arrays.asList(protocolName); } @Override public void unsupported() { latch.countDown(); } @Override public void selected(String protocol) { Assert.fail(); } }; ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() { @Override public void unsupported() { Assert.fail(); } @Override public String select(List<String> protocols) { Assert.assertEquals(1, protocols.size()); String protocol = protocols.get(0); Assert.assertEquals(protocolName, protocol); latch.countDown(); // By returning null, the server won't send the ALPN extension. return null; } }; SSLResult<T> sslResult = performTLSHandshake(null, clientProvider, serverProvider); Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); // Verify that we can exchange data without errors. performDataExchange(sslResult); performTLSClose(sslResult); } @Test public void testServerThrowsException() throws Exception { final String protocolName = "test"; ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() { @Override public List<String> protocols() { return Arrays.asList(protocolName); } @Override public void unsupported() { Assert.fail(); } @Override public void selected(String protocol) { Assert.fail(); } }; ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() { @Override public void unsupported() { Assert.fail(); } @Override public String select(List<String> protocols) throws SSLException { // By throwing, the server will close the connection. throw new SSLHandshakeException("explicitly_thrown_by_test"); } }; try { performTLSHandshake(null, clientProvider, serverProvider); Assert.fail(); } catch (SSLHandshakeException x) { // Expected. } } @Test public void testClientThrowsException() throws Exception { final String protocolName = "test"; ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() { @Override public List<String> protocols() { return Arrays.asList(protocolName); } @Override public void unsupported() { Assert.fail(); } @Override public void selected(String protocol) throws SSLException { if (!protocolName.equals(protocol)) throw new SSLHandshakeException("explicitly_thrown_by_test"); } }; ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() { @Override public void unsupported() { Assert.fail(); } @Override public String select(List<String> protocols) throws SSLException { // Return a protocol that the client does not support. return "boom." + protocolName; } }; try { performTLSHandshake(null, clientProvider, serverProvider); Assert.fail(); } catch (SSLHandshakeException x) { // Expected. } } @Test public void testClientTLSRenegotiation() throws Exception { testTLSRenegotiation(true); } @Test public void testServerTLSRenegotiation() throws Exception { testTLSRenegotiation(false); } private void testTLSRenegotiation(boolean client) throws Exception { final String protocolName = "test"; final AtomicReference<CountDownLatch> latch = new AtomicReference<>(new CountDownLatch(3)); ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() { @Override public List<String> protocols() { latch.get().countDown(); return Arrays.asList(protocolName); } @Override public void unsupported() { latch.get().countDown(); Assert.fail(); } @Override public void selected(String protocol) { Assert.assertEquals(protocolName, protocol); latch.get().countDown(); } }; ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() { @Override public void unsupported() { latch.get().countDown(); Assert.fail(); } @Override public String select(List<String> protocols) { Assert.assertEquals(1, protocols.size()); String protocol = protocols.get(0); Assert.assertEquals(protocolName, protocol); latch.get().countDown(); return protocol; } }; SSLResult<T> sslResult = performTLSHandshake(null, clientProvider, serverProvider); Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); // Verify that we can exchange data without errors. performDataExchange(sslResult); latch.set(new CountDownLatch(1)); performTLSRenegotiation(sslResult, client); // The data exchange may trigger the completion of the TLS renegotiation. performDataExchange(sslResult); // ALPN must not trigger. Assert.assertFalse(latch.get().await(1, TimeUnit.SECONDS)); performTLSClose(sslResult); } @Test public void testTLSSessionResumption() throws Exception { final String protocolName = "test"; final AtomicReference<CountDownLatch> latch = new AtomicReference<>(); ALPN.ClientProvider clientProvider = new ALPN.ClientProvider() { @Override public List<String> protocols() { latch.get().countDown(); return Arrays.asList(protocolName); } @Override public void unsupported() { Assert.fail(); } @Override public void selected(String protocol) { Assert.assertEquals(protocolName, protocol); latch.get().countDown(); } }; ALPN.ServerProvider serverProvider = new ALPN.ServerProvider() { @Override public void unsupported() { Assert.fail(); } @Override public String select(List<String> protocols) { Assert.assertEquals(1, protocols.size()); String protocol = protocols.get(0); Assert.assertEquals(protocolName, protocol); latch.get().countDown(); return protocol; } }; // First TLS handshake. latch.set(new CountDownLatch(3)); SSLResult<T> sslResult = performTLSHandshake(null, clientProvider, serverProvider); Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); SSLSession clientSession1 = getSSLSession(sslResult, true); SSLSession serverSession1 = getSSLSession(sslResult, false); // Must close the first session before starting the second one. performTLSClose(sslResult); // Second TLS handshake. latch.set(new CountDownLatch(3)); sslResult = performTLSHandshake(sslResult, clientProvider, serverProvider); Assert.assertTrue(latch.get().await(5, TimeUnit.SECONDS)); Assert.assertSame(clientSession1, getSSLSession(sslResult, true)); Assert.assertSame(serverSession1, getSSLSession(sslResult, false)); performTLSClose(sslResult); } public static class SSLResult<S> { public SSLContext context; public S client; public S server; } }