/**
 * Copyright 2005-2014 Restlet
 * 
 * The contents of this file are subject to the terms of one of the following
 * open source licenses: Apache 2.0 or or EPL 1.0 (the "Licenses"). You can
 * select the license that you prefer but you may not use this file except in
 * compliance with one of these Licenses.
 * 
 * You can obtain a copy of the Apache 2.0 license at
 * http://www.opensource.org/licenses/apache-2.0
 * 
 * You can obtain a copy of the EPL 1.0 license at
 * http://www.opensource.org/licenses/eclipse-1.0
 * 
 * See the Licenses for the specific language governing permissions and
 * limitations under the Licenses.
 * 
 * Alternatively, you can obtain a royalty free commercial license with less
 * limitations, transferable or non-transferable, directly at
 * http://restlet.com/products/restlet-framework
 * 
 * Restlet is a registered trademark of Restlet S.A.S.
 */

package org.restlet.engine.ssl;

import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.net.UnknownHostException;

import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

/**
 * SSL socket factory decorating another {@link SSLSocketFactory}. Every socket
 * obtained from the wrapped factory is passed through
 * {@link #initSslSocket(SSLSocket)} before being handed to the caller, which
 * applies the client authentication mode, the cipher suites and the protocols
 * selected by the parent {@link DefaultSslContextFactory}.
 * 
 * @author Jerome Louvel
 */
public class WrapperSslSocketFactory extends SSLSocketFactory {

    /** The parent SSL context factory supplying the socket settings. */
    private final DefaultSslContextFactory contextFactory;

    /** The wrapped SSL socket factory that actually creates the sockets. */
    private final SSLSocketFactory wrappedSocketFactory;

    /**
     * Constructor.
     * 
     * @param contextFactory
     *            The parent SSL context factory.
     * @param wrappedSocketFactory
     *            The wrapped SSL socket factory.
     */
    public WrapperSslSocketFactory(DefaultSslContextFactory contextFactory,
            SSLSocketFactory wrappedSocketFactory) {
        this.contextFactory = contextFactory;
        this.wrappedSocketFactory = wrappedSocketFactory;
    }

    /**
     * Asks the wrapped factory for an unconnected socket, then initializes it.
     */
    @Override
    public Socket createSocket() throws IOException {
        final Socket created = getWrappedSocketFactory().createSocket();
        return initSslSocket((SSLSocket) created);
    }

    /**
     * Asks the wrapped factory for a socket to the given address and port,
     * then initializes it.
     */
    @Override
    public Socket createSocket(InetAddress host, int port) throws IOException {
        final Socket created = getWrappedSocketFactory().createSocket(host,
                port);
        return initSslSocket((SSLSocket) created);
    }

    /**
     * Asks the wrapped factory for a socket to the given remote address and
     * port using the given local address and port, then initializes it.
     */
    @Override
    public Socket createSocket(InetAddress host, int port,
            InetAddress localAddress, int localPort) throws IOException {
        final Socket created = getWrappedSocketFactory().createSocket(host,
                port, localAddress, localPort);
        return initSslSocket((SSLSocket) created);
    }

    /**
     * Asks the wrapped factory for a socket layered over the given existing
     * socket, then initializes it.
     */
    @Override
    public Socket createSocket(Socket s, String host, int port,
            boolean autoClose) throws IOException {
        final Socket created = getWrappedSocketFactory().createSocket(s, host,
                port, autoClose);
        return initSslSocket((SSLSocket) created);
    }

    /**
     * Asks the wrapped factory for a socket to the given host name and port,
     * then initializes it.
     */
    @Override
    public Socket createSocket(String host, int port) throws IOException,
            UnknownHostException {
        final Socket created = getWrappedSocketFactory().createSocket(host,
                port);
        return initSslSocket((SSLSocket) created);
    }

    /**
     * Asks the wrapped factory for a socket to the given host name and port
     * using the given local address and port, then initializes it.
     */
    @Override
    public Socket createSocket(String host, int port,
            InetAddress localAddress, int localPort) throws IOException,
            UnknownHostException {
        final Socket created = getWrappedSocketFactory().createSocket(host,
                port, localAddress, localPort);
        return initSslSocket((SSLSocket) created);
    }

    /**
     * Returns the parent SSL context factory.
     * 
     * @return The parent SSL context factory.
     */
    public DefaultSslContextFactory getContextFactory() {
        return this.contextFactory;
    }

    /**
     * Returns the default cipher suites reported by the wrapped factory.
     */
    @Override
    public String[] getDefaultCipherSuites() {
        final SSLSocketFactory delegate = getWrappedSocketFactory();
        return delegate.getDefaultCipherSuites();
    }

    /**
     * Returns the supported cipher suites reported by the wrapped factory.
     */
    @Override
    public String[] getSupportedCipherSuites() {
        final SSLSocketFactory delegate = getWrappedSocketFactory();
        return delegate.getSupportedCipherSuites();
    }

    /**
     * Returns the wrapped SSL socket factory.
     * 
     * @return The wrapped SSL socket factory.
     */
    public SSLSocketFactory getWrappedSocketFactory() {
        return this.wrappedSocketFactory;
    }

    /**
     * Applies the settings of the parent context factory to a socket.
     * <ul>
     * <li>Client authentication: flagged as needed when the context factory
     * needs it, otherwise flagged as wanted when the context factory wants it,
     * otherwise left untouched.</li>
     * <li>Cipher suites: when the context factory has enabled or disabled
     * cipher suites set, the socket's enabled suites are replaced by the
     * context factory's selection among the suites the socket supports.</li>
     * <li>Protocols: when the context factory has enabled or disabled
     * protocols set, the socket's enabled protocols are replaced by the
     * context factory's selection among the protocols the socket supports.</li>
     * </ul>
     * 
     * @param sslSocket
     *            The socket to initialize.
     * @return The same socket instance, once initialized.
     */
    protected SSLSocket initSslSocket(SSLSocket sslSocket) {
        final DefaultSslContextFactory factory = getContextFactory();

        // "Need" takes precedence over "want".
        if (factory.isNeedClientAuthentication()) {
            sslSocket.setNeedClientAuth(true);
        } else if (factory.isWantClientAuthentication()) {
            sslSocket.setWantClientAuth(true);
        }

        // Only override the socket's cipher suites when something was
        // explicitly configured on the context factory.
        final boolean cipherSuitesConfigured = (factory
                .getEnabledCipherSuites() != null)
                || (factory.getDisabledCipherSuites() != null);
        if (cipherSuitesConfigured) {
            final String[] supportedSuites = sslSocket
                    .getSupportedCipherSuites();
            sslSocket.setEnabledCipherSuites(factory
                    .getSelectedCipherSuites(supportedSuites));
        }

        // Same rule for the protocols.
        final boolean protocolsConfigured = (factory.getEnabledProtocols() != null)
                || (factory.getDisabledProtocols() != null);
        if (protocolsConfigured) {
            final String[] supportedProtocols = sslSocket
                    .getSupportedProtocols();
            sslSocket.setEnabledProtocols(factory
                    .getSelectedSslProtocols(supportedProtocols));
        }

        return sslSocket;
    }

}