/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.jini.jeri.ssl; import com.sun.jini.action.GetLongAction; import com.sun.jini.logging.Levels; import com.sun.jini.logging.LogUtil; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; import java.net.SocketAddress; import java.net.SocketException; import java.net.SocketTimeoutException; import java.net.UnknownHostException; import java.nio.channels.SocketChannel; import java.util.Collection; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.SocketFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLException; import javax.net.ssl.SSLProtocolException; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import javax.security.auth.x500.X500Principal; import net.jini.core.constraint.InvocationConstraints; import net.jini.io.UnsupportedConstraintException; import net.jini.jeri.connection.Connection; import net.jini.jeri.connection.OutboundRequestHandle; import net.jini.security.Security; /** * Implementation of Connection used by SslEndpoint. * * @author Sun Microsystems, Inc. */ class SslConnection extends Utilities implements Connection { /* -- Fields -- */ /** * The maximum time a client session should be used before expiring -- * non-final to facilitate testing. Use 23.5 hours as the default to allow * the client to negotiate a new session before the server timeout, which * defaults to 24 hours. */ private static long maxClientSessionDuration = ((Long) Security.doPrivileged( new GetLongAction("com.sun.jini.jeri.ssl.maxClientSessionDuration", (long) (23.5 * 60 * 60 * 1000)))).longValue(); /** Client logger */ private static final Logger logger = clientLogger; /** The server host */ final String serverHost; /** The server port */ final int port; /** * The socket factory for creating plain sockets, or null to use default * sockets. */ final SocketFactory socketFactory; /** The call context specified when the connection was made */ final CallContext callContext; /** * The SSLContext -- only shared by connections with the same host, port, * suite, and principals. */ private final SSLContext sslContext; /** The factory for creating SSL sockets. */ final SSLSocketFactory sslSocketFactory; /** The authentication manager. */ private final ClientAuthManager authManager; /** The socket */ SSLSocket sslSocket; /** The currently active cipher suite */ private String activeCipherSuite; /** The current session */ private SSLSession session; /** True if the connection has been closed. */ boolean closed; /* -- Methods -- */ /** * Creates a connection. * * @param callContext the call context to establish * @param serverHost the server host to connect to * @param port the server port to connect to * @param socketFactory the socket factory, or null to use default sockets */ SslConnection(CallContext callContext, String serverHost, int port, SocketFactory socketFactory) { this.serverHost = serverHost; this.port = port; this.socketFactory = socketFactory; if (callContext == null) { throw new NullPointerException("Call context cannot be null"); } this.callContext = callContext; SSLContextInfo info = getClientSSLContextInfo(callContext); sslContext = info.sslContext; sslSocketFactory = sslContext.getSocketFactory(); authManager = (ClientAuthManager) info.authManager; } /** * Establishes a cipher suite on this connection as specified by the call * context. * * @throws UnsupportedSecurityException if the requested constraints cannot * be supported * @throws IOException if an I/O failure occurs * @throws SecurityException if the current access control context does not * have the proper AuthenticationPermission */ final void establishCallContext() throws IOException { Exception exception; try { establishNewSocket(); if (callContext.clientAuthRequired && !authManager.getClientAuthenticated()) { Exception credExcept = authManager.getClientCredentialException(); /* * Don't throw the exception that occurred when getting client * credentials if the caller doesn't have access to the * subject. */ SecurityManager sm = System.getSecurityManager(); if (sm != null) { try { sm.checkPermission(getSubjectPermission); } catch (SecurityException e) { credExcept = null; } } if (credExcept instanceof SecurityException) { exception = (SecurityException) credExcept; } else { exception = new UnsupportedConstraintException( "Client not authenticated", credExcept); } } else { if (logger.isLoggable(Level.FINE)) { logger.log(Level.FINE, "new connection for {0}\ncreates {1}", new Object[] { callContext, this }); } return; } } catch (SSLProtocolException e) { /* * Don't throw an UnsupportedConstraintException -- this is a * problem within the SSL implementation. */ exception = e; } catch (SSLException e) { exception = new UnsupportedConstraintException(e.getMessage(), e); } catch (IOException e) { exception = e; } catch (SecurityException e) { exception = e; } if (logger.isLoggable(Levels.FAILED)) { logThrow(logger, Levels.FAILED, SslConnection.class, "establishCallContext", "new connection for {0}\nthrows", new Object[] { callContext }, exception); } closeSocket(); if (exception instanceof IOException) { throw (IOException) exception; } else { throw (SecurityException) exception; } } /** Closes the socket for this connection. */ private void closeSocket() { if (sslSocket != null) { try { sslSocket.close(); } catch (IOException e) { } sslSocket = null; session = null; activeCipherSuite = null; } } /** * Attempts to create a new socket for the call context and cipher suites. * * @throws SSLException if the suites cannot be supported * @throws IOException if an I/O failure occurs */ void establishNewSocket() throws IOException { Socket socket = createPlainSocket(serverHost, port); sslSocket = (SSLSocket) sslSocketFactory.createSocket( socket, serverHost, port, /* autoClose */ true); establishSuites(); } /** * Attempts to establish the call context and suites on the current socket. * * @throws SSLException if the requested suites cannot be supported * @throws IOException if an I/O failure occurs */ final void establishSuites() throws IOException { sslSocket.setEnabledCipherSuites(callContext.cipherSuites); sslSocket.startHandshake(); session = sslSocket.getSession(); activeCipherSuite = session.getCipherSuite(); sslSocket.setEnableSessionCreation(false); releaseClientSSLContextInfo(callContext, sslContext, authManager); } /** * Creates a plain socket to use for communication with the specified host * and port. */ final Socket createPlainSocket(String host, int port) throws IOException { Socket socket; if (!callContext.endpointImpl.disableSocketConnect) { /* Connect with proper timeout */ socket = connectToHost(host, port, callContext.connectionTime); } else { socket = newSocket(); } return socket; } private static int computeTimeout(long connectionTime) throws IOException { int timeout; long current = System.currentTimeMillis(); if (connectionTime == -1) { timeout = 0; } else if (connectionTime < current) { throw new IOException("Connection not made within specified time"); } else if (connectionTime - current > Integer.MAX_VALUE) { timeout = 0; } else { timeout = (int) (connectionTime - current); } return timeout; } /** * Returns a socket connected to the specified host and port, * according to the specified constraints. If the host name * resolves to multiple addresses, attempts to connect to each of * them in order until one succeeds. **/ private Socket connectToHost(String host, int port, long connectionTime) throws IOException { InetAddress[] addresses; try { addresses = InetAddress.getAllByName(host); } catch (UnknownHostException uhe) { try { /* * Creating the InetSocketAddress attempts to * resolve the host again; in J2SE 5.0, there is a * factory method for creating an unresolved * InetSocketAddress directly. */ return connectToSocketAddress( new InetSocketAddress(host, port), connectionTime); } catch (IOException e) { if (logger.isLoggable(Levels.FAILED)) { LogUtil.logThrow(logger, Levels.FAILED, SslConnection.class, "connectToHost", "exception connecting to unresolved host {0}", new Object[] { host + ":" + port }, e); } throw e; } catch (SecurityException e) { if (logger.isLoggable(Levels.FAILED)) { LogUtil.logThrow(logger, Levels.FAILED, SslConnection.class, "connectToHost", "exception connecting to unresolved host {0}", new Object[] { host + ":" + port }, e); } throw e; } } catch (SecurityException e) { if (logger.isLoggable(Levels.FAILED)) { LogUtil.logThrow(logger, Levels.FAILED, SslConnection.class, "connectToHost", "exception resolving host {0}", new Object[] { host }, e); } throw e; } IOException lastIOException = null; SecurityException lastSecurityException = null; for (int i = 0; i < addresses.length; i++) { SocketAddress socketAddress = new InetSocketAddress(addresses[i], port); try { return connectToSocketAddress(socketAddress, connectionTime); } catch (IOException e) { if (logger.isLoggable(Levels.HANDLED)) { LogUtil.logThrow(logger, Levels.HANDLED, SslConnection.class, "connectToHost", "exception connecting to {0}", new Object[] { socketAddress }, e); } lastIOException = e; if (e instanceof SocketTimeoutException) { break; } } catch (SecurityException e) { if (logger.isLoggable(Levels.HANDLED)) { LogUtil.logThrow(logger, Levels.HANDLED, SslConnection.class, "connectToHost", "exception connecting to {0}", new Object[] { socketAddress }, e); } lastSecurityException = e; } } if (lastIOException != null) { if (logger.isLoggable(Levels.FAILED)) { LogUtil.logThrow(logger, Levels.FAILED, SslConnection.class, "connectToHost", "exception connecting to {0}", new Object[] { host + ":" + port }, lastIOException); } throw lastIOException; } assert lastSecurityException != null; if (logger.isLoggable(Levels.FAILED)) { LogUtil.logThrow(logger, Levels.FAILED, SslConnection.class, "connectToHost", "exception connecting to {0}", new Object[] { host + ":" + port }, lastSecurityException); } throw lastSecurityException; } /** * Returns a socket connected to the specified address, with a * timeout governed by the specified absolute connection time. **/ private Socket connectToSocketAddress(SocketAddress socketAddress, long connectionTime) throws IOException { int timeout = computeTimeout(connectionTime); Socket socket = newSocket(); boolean ok = false; try { socket.connect(socketAddress, timeout); ok = true; return socket; } finally { if (!ok) { try { socket.close(); } catch (IOException e) { } } } } /** * Returns a new unconnected socket, using this endpoint's * socket factory if non-null. **/ private Socket newSocket() throws IOException { Socket socket = socketFactory != null ? socketFactory.createSocket() : new Socket(); /* Send data without delay */ try { socket.setTcpNoDelay(true); } catch (SocketException e) { } /* Send periodic pings so we can tell if the connection dies. */ try { socket.setKeepAlive(true); } catch (SocketException e) { } return socket; } /** Returns a string representation of this object. */ public String toString() { String sessionString = (session == null) ? "" : session + ", "; return getClassName(this) + "[" + sessionString + (sslSocket == null ? "???" : Integer.toString(sslSocket.getLocalPort())) + "=>" + serverHost + ":" + port + "]"; } /* -- Implement Connection -- */ /* inherit javadoc */ public InputStream getInputStream() throws IOException { if (sslSocket != null) { return sslSocket.getInputStream(); } else { throw new IOException("No socket established"); } } /* inherit javadoc */ public OutputStream getOutputStream() throws IOException { if (sslSocket != null) { return sslSocket.getOutputStream(); } else { throw new IOException("No socket established"); } } /* inherit javadoc */ public SocketChannel getChannel() { return null; } /* inherit javadoc */ public void populateContext(OutboundRequestHandle handle, Collection context) { CallContext.coerce(handle, callContext.endpoint); if (context == null) { throw new NullPointerException("Context cannot be null"); } /* No context info */ } /* inherit javadoc */ public InvocationConstraints getUnfulfilledConstraints( OutboundRequestHandle handle) { CallContext callContext = CallContext.coerce( handle, this.callContext.endpoint); return callContext.getUnfulfilledConstraints(); } /* inherit javadoc */ public void writeRequestData(OutboundRequestHandle handle, OutputStream stream) { CallContext.coerce(handle, callContext.endpoint); if (stream == null) { throw new NullPointerException("Stream cannot be null"); } /* No per-request data needed */ } /* inherit javadoc */ public IOException readResponseData(OutboundRequestHandle handle, InputStream stream) { CallContext.coerce(handle, callContext.endpoint); if (stream == null) { throw new NullPointerException("Stream cannot be null"); } /* No per-response data needed */ return null; } /* inherit javadoc */ public synchronized void close() throws IOException { if (!closed) { logger.log(Level.FINE, "closing {0}", this); closed = true; closeSocket(); } } /** * Returns true if this connection is compatible with the specified call * context. */ final boolean useFor(CallContext otherCallContext) { assert callContext.endpoint.equals(otherCallContext.endpoint); if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "try {0}\nwith {1}\nfor {2}", new Object[] { this, callContext, otherCallContext }); } /* Check that connection is established */ if (session == null) { logger.log(Level.FINEST, "connection {0} is not established", this); return false; } /* Check if session is expired */ if (checkSessionExpired()) { logger.log(Level.FINE, "connection {0} session is expired", this); return false; } /* Check client subject -- only use if both specified and '==' */ if (callContext.clientSubject != otherCallContext.clientSubject) { logger.log(Level.FINEST, "connection has wrong subject"); return false; } /* Check client principals */ X500Principal clientPrincipal = authManager.getClientPrincipal(); if (clientPrincipal == null) { if (otherCallContext.clientAuthRequired) { logger.log(Level.FINEST, "connection has no client authentication"); return false; } } else if (otherCallContext.clientPrincipals != null && !otherCallContext.clientPrincipals.contains(clientPrincipal)) { logger.log(Level.FINEST, "connection has wrong client principal"); return false; } /* Check server principals */ X500Principal serverPrincipal = authManager.getServerPrincipal(); if (serverPrincipal != null && otherCallContext.serverPrincipals != null && !otherCallContext.serverPrincipals.contains(serverPrincipal)) { logger.log(Level.FINEST, "connection has wrong server principal"); return false; } /* Check that active suite is one of the requested suites */ String[] requestedSuites = otherCallContext.cipherSuites; int requestedPos = position(activeCipherSuite, requestedSuites); if (requestedPos < 0) { logger.log(Level.FINEST, "connection has wrong suite"); return false; } /* * Check that suites that would be better than the suite active on the * connection are also better for this connection's call context, * meaning that they probably wouldn't have worked anyway. */ String[] connectionSuites = callContext.cipherSuites; int connectionPos = position(activeCipherSuite, connectionSuites); assert connectionPos >= 0 : "Couldn't find connection suite"; for (int i = requestedPos; --i >= 0; ) { String suite = requestedSuites[i]; int p = position(suite, connectionSuites); if (p < 0 || p >= connectionPos) { logger.log(Level.FINEST, "connection did not try all better suites"); return false; } } /* Check client authentication credentials */ if (clientPrincipal != null) { Exception exception; try { authManager.checkAuthentication(); exception = null; } catch (SecurityException e) { exception = e; } catch (UnsupportedConstraintException e) { exception = e; } if (exception != null) { if (logger.isLoggable(Level.FINEST)) { logThrow(logger, Level.FINEST, SslConnection.class, "useFor", "connection {0} has missing subject credentials", new Object[] { this }, exception); } return false; } } /* Looks OK */ logger.log(Level.FINEST, "connection OK"); return true; } /** * Checks if the session currently active on the connection has been active * for longer than maxClientSessionDuration and, if so, invalidates the * session. */ private boolean checkSessionExpired() { long create = session.getCreationTime(); long expiration = create + maxClientSessionDuration; /* Check for rollover */ if (expiration < create) { expiration = Long.MAX_VALUE; } if (expiration <= System.currentTimeMillis()) { session.invalidate(); return true; } else { return false; } } /** * Return HTTP proxy host if present, an empty string otherwise. */ protected String getProxyHost() { return ""; } void checkConnectPermission() { SecurityManager sm = System.getSecurityManager(); if (sm != null) { Socket socket = sslSocket; // This depends on the SslSocket returning information about // its underlying plain socket. InetSocketAddress address = (InetSocketAddress) socket.getRemoteSocketAddress(); if (address.isUnresolved()) { sm.checkConnect(address.getHostName(), socket.getPort()); } else { sm.checkConnect(address.getAddress().getHostAddress(), socket.getPort()); } } } }