package sslnpn;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocketFactory;
import org.junit.Before;
import org.junit.Test;
import sslnpn.ssl.NextProtocolNegotiationChooser;
import sslnpn.ssl.SSLEngineImpl;
import sslnpn.ssl.SSLServerSocketImpl;
import sslnpn.ssl.SSLSocketImpl;
/**
 * Interoperability tests between this project's NPN-capable SSL classes ({@link SSLSocketImpl},
 * {@link SSLServerSocketImpl}, {@link SSLEngineImpl}) and the {@code openssl} command-line tool
 * ({@code s_client} / {@code s_server}).
 *
 * <p>Every test is skipped (via {@code assumeTrue}) unless an {@code openssl} binary whose
 * {@code s_client} usage text mentions {@code nextprotoneg} can be started. The
 * {@code s_server}-based tests pass {@code server.key} / {@code server.crt} as relative paths,
 * so those files must be present in the working directory.
 */
public class OpenSSLCompatabilityTest {
    /** Temporary file {@code s_client} writes its session to ({@code -sess_out}) and resumes from ({@code -sess_in}). */
    private File sessionFile = null;
    private static final long TIMEOUT = 5000;
    /**
     * Next port handed to a spawned {@code openssl s_server}. This is static because JUnit 4
     * creates a new instance of the test class for every test method. As an instance field,
     * every test would get 8443 and could collide with a previous test's server that is still
     * shutting down ({@link Process#destroy()} does not wait for the process to exit).
     */
    private static int nextPort = 8443;
    /** Lazily created by {@link #createContext()}; one per test instance. */
    private SSLContext context;
    private static final boolean DEBUG = true;

    @Before
    public void before() throws Exception {
        assumeTrue(hasOpensslWithNextProtocolNegotiation());
        if (DEBUG) {
            debug();
        }
        sessionFile = File.createTempFile("sess", "sess");
        // The session file is only meaningful for the current run; don't litter the temp dir.
        sessionFile.deleteOnExit();
    }

    /**
     * Returns true if {@code openssl s_client} can be started and its usage output mentions
     * {@code nextprotoneg}. Returns false if the process cannot be started at all.
     */
    private boolean hasOpensslWithNextProtocolNegotiation() throws Exception {
        try {
            Process p = new ProcessBuilder("openssl", "s_client", "--help").redirectErrorStream(true).start();
            // drain() reads the usage text to EOF and destroys the process when done.
            String output = drain(p, "TEST>");
            return output.contains("nextprotoneg");
        } catch (IOException e) {
            // openssl is not on the PATH (or could not be launched): treat as unsupported.
            return false;
        }
    }

    private void debug() {
        System.setProperty("javax.net.debug", "all");
    }

    /**
     * Runs {@code openssl s_client} against {@code localhost:port}, offering
     * {@code http/1.0,spdy/2} via NPN, and returns the client's combined stdout/stderr.
     * The session is always written to {@code sessionFile}. When {@code newSession} is false,
     * it is also read back from that file so the client attempts session resumption.
     */
    public static class SpawnSSLClient implements Callable<String> {
        private final File sessionFile;
        private final boolean newSession;
        private final int port;

        public SpawnSSLClient(int port, File sessionFile, boolean newSession) {
            this.sessionFile = sessionFile;
            this.newSession = newSession;
            this.port = port;
        }

        @Override
        public String call() throws Exception {
            List<String> command = new ArrayList<String>(Arrays.<String> asList("openssl", "s_client", "-nextprotoneg",
                    "http/1.0,spdy/2", "-host", "localhost", "-port", String.valueOf(port), "-sess_out",
                    sessionFile.getAbsolutePath()));
            if (!newSession) {
                if (DEBUG) {
                    System.out.println("C>Resuming session");
                }
                command.add("-sess_in");
                command.add(sessionFile.getAbsolutePath());
            }
            Process process = new ProcessBuilder(command).redirectErrorStream(true).start();
            return drain(process, "C>");
        }
    }

    /**
     * Reads the process's output line by line until EOF and returns it, with each line
     * terminated by {@code "\n"}. Each line is echoed with {@code prefix} when {@link #DEBUG}
     * is set. The process is destroyed before returning.
     *
     * @param process process whose (stdout) stream is read; callers here merge stderr into it
     * @param prefix  tag prepended to echoed lines, to tell interleaved process output apart
     * @return the text collected before EOF or before the stream failed
     * @throws IOException if closing the reader fails
     */
    public static String drain(Process process, String prefix) throws IOException {
        try {
            StringBuilder builder = new StringBuilder();
            try (BufferedReader bufferedReader =
                    new BufferedReader(new InputStreamReader(process.getInputStream(), "UTF-8"))) {
                try {
                    String line;
                    while ((line = bufferedReader.readLine()) != null) {
                        if (OpenSSLCompatabilityTest.DEBUG) {
                            System.out.println(prefix + line);
                        }
                        builder.append(line).append("\n");
                    }
                } catch (IOException ignored) {
                    // Deliberately best-effort. Tests destroy s_server from another thread while
                    // this is reading, which can make the read fail. Return whatever was collected.
                }
            }
            return builder.toString();
        } finally {
            process.destroy();
        }
    }

    /** Client-side NPN chooser that always selects {@code http/1.1}, whatever the server offers. */
    public static class Chooser implements NextProtocolNegotiationChooser {
        @Override
        public String chooseProtocol(List<String> protocols) {
            return "http/1.1";
        }
    }

    /**
     * Runs {@code task} on a dedicated background thread. The executor is shut down
     * immediately, so its thread exits once the task completes instead of leaking.
     */
    private static <T> Future<T> submitInBackground(Callable<T> task) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            return executor.submit(task);
        } finally {
            executor.shutdown();
        }
    }

    /**
     * Spawns an {@code s_client} against {@code serverSocket}, accepts it, completes the
     * handshake and sends {@code "helloworld\n"}. It then checks the negotiated protocol and
     * the client's output.
     *
     * @param checkSpdy  whether {@code spdy/2} is expected to be negotiated (otherwise no NPN result)
     * @param newSession false to make the client resume the session saved in {@link #sessionFile}
     */
    private void handleClientSocket(SSLServerSocketImpl serverSocket, boolean checkSpdy, boolean newSession)
            throws Exception {
        Future<String> sslOutput = submitInBackground(
                new SpawnSSLClient(serverSocket.getLocalPort(), this.sessionFile, newSession));
        SSLSocketImpl socket = (SSLSocketImpl) serverSocket.accept();
        try {
            socket.startHandshake();
            socket.getOutputStream().write("helloworld\n".getBytes("UTF-8"));
        } finally {
            // Closing the connection is what lets s_client exit and drain() reach EOF.
            socket.close();
        }
        String output = sslOutput.get(2, TimeUnit.SECONDS);
        if (checkSpdy) {
            assertEquals("spdy/2", socket.getNegotiatedNextProtocol());
            assertTrue("Failed to negotiate spdy/2 in client output: " + output,
                    output.contains("Next protocol: (1) spdy/2"));
        } else {
            assertEquals(null, socket.getNegotiatedNextProtocol());
        }
        assertTrue("Failed to receive helloworld in client output: " + output, output.contains("helloworld"));
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLServerSocketWithOpenSSLAndNoSpdy() throws Exception {
        SSLServerSocketFactory factory = createContext().getServerSocketFactory();
        try (SSLServerSocketImpl serverSocket = (SSLServerSocketImpl) factory.createServerSocket(0)) {
            serverSocket.setReuseAddress(true);
            handleClientSocket(serverSocket, false, true);
            handleClientSocket(serverSocket, false, false);
        }
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLServerSocketWithOpenSSLAndResumeAndSpdy() throws Exception {
        SSLServerSocketFactory factory = createContext().getServerSocketFactory();
        try (SSLServerSocketImpl serverSocket = (SSLServerSocketImpl) factory.createServerSocket(0)) {
            serverSocket.setReuseAddress(true);
            serverSocket.setAdvertisedNextProtocols("http/1.1", "spdy/2");
            handleClientSocket(serverSocket, true, true);
            handleClientSocket(serverSocket, true, false);
        }
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLServerSocketWithOpenSSL() throws Exception {
        SSLServerSocketFactory factory = createContext().getServerSocketFactory();
        try (SSLServerSocketImpl serverSocket = (SSLServerSocketImpl) factory.createServerSocket(0)) {
            serverSocket.setReuseAddress(true);
            serverSocket.setAdvertisedNextProtocols("http/1.1", "spdy/2");
            handleClientSocket(serverSocket, true, true);
        }
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLClientEngineWithOpenSSL() throws Exception {
        try (ProcessAndPort sslServer = spawnOpensslServer()) {
            Future<String> sslOutput = submitInBackground(new DrainOutput(sslServer.process));
            SSLEngineImpl engine = (SSLEngineImpl) createContext().createSSLEngine();
            engine.setNpnChooser(new Chooser());
            engine.setUseClientMode(true);
            try (SocketChannel socket = SocketChannel.open()) {
                socket.connect(new InetSocketAddress(InetAddress.getLoopbackAddress(), sslServer.port));
                SSLEngineHandshaker.negotiateHandshake(engine, socket);
            }
            // s_server never exits on its own; stop it so DrainOutput reaches end of stream.
            sslServer.process.destroy();
            assertEquals("http/1.1", engine.getNegotiatedNextProtocol());
            String output = sslOutput.get(2, TimeUnit.SECONDS);
            assertTrue("Failed to receive http/1.1 in server output: " + output, output.contains("http/1.1"));
        }
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLServerEngineWithOpenSSL() throws Exception {
        SSLContext sslContext = createContext();
        try (ServerSocketChannel serverSocket = ServerSocketChannel.open()) {
            serverSocket.bind(new InetSocketAddress(0));
            int port = ((InetSocketAddress) serverSocket.getLocalAddress()).getPort();
            SSLEngineImpl engine = (SSLEngineImpl) sslContext.createSSLEngine();
            engine.setAdvertisedNextProtocols("http/1.1", "spdy/2");
            engine.setUseClientMode(false);
            Future<String> sslOutput = submitInBackground(new SpawnSSLClient(port, this.sessionFile, true));
            try (SocketChannel socket = serverSocket.accept()) {
                SSLEngineHandshaker.negotiateHandshake(engine, socket);
                assertEquals("spdy/2", engine.getNegotiatedNextProtocol());
            }
            String output = sslOutput.get(2, TimeUnit.SECONDS);
            assertTrue("Failed to receive spdy/2 in client output: " + output, output.contains("spdy/2"));
        }
    }

    /** Collects a process's output on a background thread (see {@link #drain}). */
    public static class DrainOutput implements Callable<String> {
        private final Process process;

        public DrainOutput(Process is) {
            this.process = is;
        }

        @Override
        public String call() throws Exception {
            return drain(process, "S>");
        }
    }

    /** A spawned {@code openssl s_server} and the port it was told to listen on; closing destroys the process. */
    static class ProcessAndPort implements AutoCloseable {
        private final Process process;
        private final int port;

        public ProcessAndPort(Process process, int port) {
            this.port = port;
            this.process = process;
        }

        @Override
        public void close() throws Exception {
            this.process.destroy();
        }
    }

    /**
     * Starts {@code openssl s_server} advertising {@code http/1.1,spdy/2} via NPN on the next
     * port from {@link #nextPort}, using {@code server.key}/{@code server.crt} from the working
     * directory.
     */
    private ProcessAndPort spawnOpensslServer() throws IOException, InterruptedException {
        int port = nextPort++;
        Process process = new ProcessBuilder("openssl", "s_server", "-nextprotoneg", "http/1.1,spdy/2", "-debug",
                "-msg", "-port", String.valueOf(port), "-key", "server.key", "-cert", "server.crt")
                .redirectErrorStream(true)
                .start();
        // Crude readiness wait: presumably long enough for s_server to bind before clients connect.
        Thread.sleep(1000);
        return new ProcessAndPort(process, port);
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLSocketWithOpensslServerAndNoNpn() throws Exception {
        try (ProcessAndPort sslServer = spawnOpensslServer()) {
            Future<String> sslOutput = submitInBackground(new DrainOutput(sslServer.process));
            SSLSocketFactory factory = createContext().getSocketFactory();
            try (SSLSocketImpl socket = (SSLSocketImpl) factory.createSocket()) {
                socket.connect(new InetSocketAddress("localhost", sslServer.port));
                socket.startHandshake();
                socket.getOutputStream().write("helloworld".getBytes("UTF-8"));
                assertEquals(null, socket.getNegotiatedNextProtocol());
            }
            // Give s_server a moment to print the received data before it is killed.
            Thread.sleep(1000);
            sslServer.process.destroy();
            String output = sslOutput.get(2, TimeUnit.SECONDS);
            assertTrue("Failed to receive helloworld in server output: " + output, output.contains("helloworld"));
        }
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLSocketWithOpensslServerWithResumption() throws Exception {
        try (ProcessAndPort sslServer = spawnOpensslServer()) {
            Future<String> sslOutput = submitInBackground(new DrainOutput(sslServer.process));
            SSLSocketImpl socket1 = connectToSSLServer(sslServer.port);
            SSLSocketImpl socket2 = connectToSSLServer(sslServer.port);
            assertTrue(Arrays.equals(socket1.getSession().getId(), socket2.getSession().getId()));
            sslServer.process.destroy();
            String output = sslOutput.get(2, TimeUnit.SECONDS);
            assertTrue("Failed to receive http/1.1 in server output: " + output,
                    output.contains("NEXTPROTO is http/1.1"));
        }
    }

    @Test(timeout = TIMEOUT)
    public void testCreateSSLSocketWithOpensslServer() throws Exception {
        try (ProcessAndPort sslServer = spawnOpensslServer()) {
            Future<String> sslOutput = submitInBackground(new DrainOutput(sslServer.process));
            connectToSSLServer(sslServer.port);
            sslServer.process.destroy();
            String output = sslOutput.get(2, TimeUnit.SECONDS);
            assertTrue("Failed to receive http/1.1 in server output: " + output,
                    output.contains("NEXTPROTO is http/1.1"));
        }
    }

    /**
     * Connects to {@code localhost:port} with the {@link Chooser} NPN chooser and completes a
     * handshake. It writes {@code "hello"}, asserts that {@code http/1.1} was negotiated, and
     * closes the socket. It then sleeps 500 ms before returning.
     *
     * @return the already-closed socket, so callers can inspect its session
     */
    private SSLSocketImpl connectToSSLServer(int port) throws Exception {
        SSLSocketFactory factory = createContext().getSocketFactory();
        SSLSocketImpl socket = (SSLSocketImpl) factory.createSocket();
        try {
            socket.setNpnChooser(new Chooser());
            socket.connect(new InetSocketAddress("localhost", port));
            socket.startHandshake();
            socket.getOutputStream().write("hello".getBytes("UTF-8"));
            assertEquals("http/1.1", socket.getNegotiatedNextProtocol());
        } finally {
            socket.close();
        }
        Thread.sleep(500);
        return socket;
    }

    /** Returns this instance's {@link SSLContext}, creating it via {@code SSLContextCreator} on first use. */
    private SSLContext createContext() throws KeyManagementException, UnrecoverableKeyException,
            NoSuchAlgorithmException, KeyStoreException, CertificateException, FileNotFoundException, IOException {
        if (context == null) {
            context = SSLContextCreator.newContext();
        }
        return context;
    }
}