package org.eclipse.jetty.server.handler; import java.io.BufferedReader; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.net.Socket; import java.security.SecureRandom; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; import javax.servlet.ServletException; import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.ssl.SslSelectChannelConnector; import org.eclipse.jetty.toolchain.test.MavenTestingUtils; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; /** * @version $Revision$ $Date$ */ public class ConnectHandlerSSLTest extends AbstractConnectHandlerTest { @BeforeClass public static void init() throws Exception { SslSelectChannelConnector connector = new SslSelectChannelConnector(); connector.setMaxIdleTime(3600000); // TODO remove String keyStorePath = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); SslContextFactory cf = connector.getSslContextFactory(); cf.setKeyStorePath(keyStorePath); cf.setKeyStorePassword("storepwd"); cf.setKeyManagerPassword("keypwd"); startServer(connector, new ServerHandler()); startProxy(); } @Test public void testGETRequest() throws Exception { String hostPort = "localhost:" + serverConnector.getLocalPort(); String request = "" + "CONNECT " + hostPort + " HTTP/1.1\r\n" + "Host: " + hostPort + "\r\n" + "\r\n"; Socket socket = newSocket(); socket.setSoTimeout(3600000); // TODO remove try { OutputStream output = socket.getOutputStream(); BufferedReader input = new BufferedReader(new InputStreamReader(socket.getInputStream())); output.write(request.getBytes("UTF-8")); output.flush(); // Expect 200 OK from the CONNECT request Response response = readResponse(input); System.err.println(response); assertEquals("200", response.getCode()); // Be sure the buffered input does not have anything buffered assertFalse(input.ready()); // Upgrade the socket to SSL SSLSocket sslSocket = wrapSocket(socket); try { output = sslSocket.getOutputStream(); input = new BufferedReader(new InputStreamReader(sslSocket.getInputStream())); request = "GET /echo HTTP/1.1\r\n" + "Host: " + hostPort + "\r\n" + "\r\n"; output.write(request.getBytes("UTF-8")); output.flush(); response = readResponse(input); assertEquals("200", response.getCode()); assertEquals("GET /echo", response.getBody()); } finally { sslSocket.close(); } } finally { socket.close(); } } @Test public void testPOSTRequests() throws Exception { String hostPort = "localhost:" + serverConnector.getLocalPort(); String request = "" + "CONNECT " + hostPort + " HTTP/1.1\r\n" + "Host: " + hostPort + "\r\n" + "\r\n"; Socket socket = newSocket(); try { OutputStream output = socket.getOutputStream(); BufferedReader input = new BufferedReader(new InputStreamReader(socket.getInputStream())); output.write(request.getBytes("UTF-8")); output.flush(); // Expect 200 OK from the CONNECT request Response response = readResponse(input); assertEquals("200", response.getCode()); // Be sure the buffered input does not have anything buffered assertFalse(input.ready()); // Upgrade the socket to SSL SSLSocket sslSocket = wrapSocket(socket); try { output = sslSocket.getOutputStream(); input = new BufferedReader(new InputStreamReader(sslSocket.getInputStream())); for (int i = 0; i < 10; ++i) { request = "" + "POST /echo?param=" + i + " HTTP/1.1\r\n" + "Host: " + hostPort + "\r\n" + "Content-Length: 5\r\n" + "\r\n" + "HELLO"; output.write(request.getBytes("UTF-8")); output.flush(); response = readResponse(input); assertEquals("200", response.getCode()); assertEquals("POST /echo?param=" + i + "\r\nHELLO", response.getBody()); } } finally { sslSocket.close(); } } finally { socket.close(); } } private SSLSocket wrapSocket(Socket socket) throws Exception { SSLContext sslContext = SSLContext.getInstance("SSLv3"); sslContext.init(null, new TrustManager[]{new AlwaysTrustManager()}, new SecureRandom()); SSLSocketFactory socketFactory = sslContext.getSocketFactory(); SSLSocket sslSocket = (SSLSocket)socketFactory.createSocket(socket, socket.getInetAddress().getHostAddress(), socket.getPort(), true); sslSocket.setUseClientMode(true); sslSocket.startHandshake(); return sslSocket; } private class AlwaysTrustManager implements X509TrustManager { public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { } public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { } public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[]{}; } } private static class ServerHandler extends AbstractHandler { public void handle(String target, Request request, HttpServletRequest httpRequest, HttpServletResponse httpResponse) throws IOException, ServletException { request.setHandled(true); String uri = httpRequest.getRequestURI(); if ("/echo".equals(uri)) { StringBuilder builder = new StringBuilder(); builder.append(httpRequest.getMethod()).append(" ").append(uri); if (httpRequest.getQueryString() != null) builder.append("?").append(httpRequest.getQueryString()); ByteArrayOutputStream baos = new ByteArrayOutputStream(); InputStream input = httpRequest.getInputStream(); int read = -1; while ((read = input.read()) >= 0) baos.write(read); baos.close(); ServletOutputStream output = httpResponse.getOutputStream(); output.println(builder.toString()); output.write(baos.toByteArray()); } else { throw new ServletException(); } } } }