/* * Copyright (c) MuleSoft, Inc. All rights reserved. http://www.mulesoft.com * The software in this package is published under the terms of the CPAL v1.0 * license, a copy of which has been included with this distribution in the * LICENSE.txt file. */ package org.mule.runtime.core.api.security.tls; import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage; import org.mule.runtime.api.exception.MuleRuntimeException; import org.mule.runtime.core.util.ArrayUtils; import java.io.IOException; import java.net.InetAddress; import java.net.Socket; import javax.net.SocketFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; /** * SSLSocketFactory decorator that restricts the available protocols and cipher suites in the sockets that are created. */ public class RestrictedSSLSocketFactory extends SSLSocketFactory { private final SSLSocketFactory sslSocketFactory; private final String[] enabledCipherSuites; private final String[] enabledProtocols; private final String[] defaultCipherSuites; private static RestrictedSSLSocketFactory defaultSocketFactory = null; public RestrictedSSLSocketFactory(SSLContext sslContext, String[] cipherSuites, String[] protocols) { this.sslSocketFactory = sslContext.getSocketFactory(); if (cipherSuites == null) { cipherSuites = sslSocketFactory.getDefaultCipherSuites(); } this.enabledCipherSuites = ArrayUtils.intersection(cipherSuites, sslSocketFactory.getSupportedCipherSuites()); this.defaultCipherSuites = ArrayUtils.intersection(cipherSuites, sslSocketFactory.getDefaultCipherSuites()); if (protocols == null) { protocols = sslContext.getDefaultSSLParameters().getProtocols(); } this.enabledProtocols = ArrayUtils.intersection(protocols, sslContext.getSupportedSSLParameters().getProtocols()); } @Override public Socket createSocket(String host, int port) throws IOException { return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(host, port)); } @Override public Socket createSocket(String host, int port, InetAddress clientAddress, int clientPort) throws IOException { return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(host, port, clientAddress, clientPort)); } @Override public Socket createSocket(InetAddress address, int port) throws IOException { return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(address, port)); } @Override public Socket createSocket(InetAddress address, int port, InetAddress clientAddress, int clientPort) throws IOException { return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(address, port, clientAddress, clientPort)); } @Override public String[] getDefaultCipherSuites() { return defaultCipherSuites; } @Override public String[] getSupportedCipherSuites() { return enabledCipherSuites; } @Override public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException { return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket(socket, host, port, autoClose)); } @Override public Socket createSocket() throws IOException { return restrictCipherSuites((SSLSocket) sslSocketFactory.createSocket()); } private SSLSocket restrictCipherSuites(SSLSocket socket) { socket.setEnabledCipherSuites(enabledCipherSuites); socket.setEnabledProtocols(enabledProtocols); return socket; } public static synchronized SocketFactory getDefault() { if (defaultSocketFactory == null) { try { TlsConfiguration configuration = new TlsConfiguration(null); configuration.initialise(true, null); defaultSocketFactory = new RestrictedSSLSocketFactory(configuration.getSslContext(), configuration.getEnabledCipherSuites(), configuration.getEnabledProtocols()); } catch (Exception e) { throw new MuleRuntimeException(createStaticMessage("Could not create the default RestrictedSSLSocketFactory"), e); } } return defaultSocketFactory; } }