/* * Copyright (c) 2001-2007 Sun Microsystems, Inc. All rights reserved. * * The Sun Project JXTA(TM) Software License * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. The end-user documentation included with the redistribution, if any, must * include the following acknowledgment: "This product includes software * developed by Sun Microsystems, Inc. for JXTA(TM) technology." * Alternately, this acknowledgment may appear in the software itself, if * and wherever such third-party acknowledgments normally appear. * * 4. The names "Sun", "Sun Microsystems, Inc.", "JXTA" and "Project JXTA" must * not be used to endorse or promote products derived from this software * without prior written permission. For written permission, please contact * Project JXTA at http://www.jxta.org. * * 5. Products derived from this software may not be called "JXTA", nor may * "JXTA" appear in their name, without prior written permission of Sun. * * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL SUN * MICROSYSTEMS OR ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * JXTA is a registered trademark of Sun Microsystems, Inc. in the United * States and other countries. * * Please see the license information page at : * <http://www.jxta.org/project/www/license.html> for instructions on use of * the license in source files. * * ==================================================================== * * This software consists of voluntary contributions made by many individuals * on behalf of Project JXTA. For more information on Project JXTA, please see * http://www.jxta.org. * * This license is based on the BSD license adopted by the Apache Foundation. 
*/
package net.jxta.impl.endpoint.tls;

import net.jxta.endpoint.EndpointAddress;
import net.jxta.endpoint.Message;
import net.jxta.endpoint.Messenger;
import net.jxta.endpoint.WireFormatMessage;
import net.jxta.endpoint.WireFormatMessageFactory;
import net.jxta.impl.membership.pse.PSECredential;
import net.jxta.impl.util.TimeUtils;
import net.jxta.logging.Logging;
import net.jxta.util.IgnoreFlushFilterOutputStream;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.Principal;
import java.security.Provider;
import java.security.Security;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * This class implements the TLS connection between two peers.
 *
 * <p/>The connection wraps a synthetic {@link TlsSocket} in an
 * {@link SSLSocket}: plaintext written by {@link #sendMessage(Message)} is
 * encrypted by the SSL socket, and decrypted plaintext is read back by an
 * internal reader thread which hands each message to the owning
 * {@link TlsTransport}.
 *
 * <p/>Properties:
 *
 * <p/>net.jxta.impl.endpoint.tls.TMFAlgorithm - if defined provides the name of
 * the trust manager factory algorithm to use.
 */
class TlsConn {

    /**
     * Logger
     */
    private static final Logger LOG = Logger.getLogger(TlsConn.class.getName());

    /**
     * Size in bytes of the buffer placed in front of the SSL socket's output
     * stream.
     */
    static final int BOSIZE = 16000;

    /**
     * TLS transport this connection is working for.
     */
    final TlsTransport transport;

    /**
     * The address of the peer to which we will be forwarding ciphertext
     * messages.
     */
    final EndpointAddress destAddr;

    /**
     * Are we client or server?
     */
    private final boolean client;

    /**
     * State of the connection
     */
    private volatile HandshakeState currentState;

    /**
     * Are we currently closing? Intended to prevent recursion in
     * {@link #close(HandshakeState)}. The flag is set and cleared by
     * {@code close()} but nothing in this class currently reads it.
     */
    private boolean closing = false;

    /**
     * Time that something "good" last happened on the connection
     */
    long lastAccessed;

    // The lock objects are deliberately created with "new String(...)" so that
    // each connection gets its own monitor rather than a shared interned literal.
    final String lastAccessedLock = new String("lastAccessedLock");
    final String closeLock = new String("closeLock");

    /**
     * Number of retransmissions we have received.
     */
    int retrans;

    /**
     * Our synthetic socket which sends and receives the ciphertext.
     */
    final TlsSocket tlsSocket;

    /**
     * The SSL context from which the socket factory for {@link #ssls} is
     * obtained.
     */
    private final SSLContext context;

    /**
     * For interfacing with TLS
     */
    private SSLSocket ssls;

    /**
     * We write our plaintext to this stream
     */
    private OutputStream plaintext_out = null;

    /**
     * Reads plaintext messages from the SSL socket's input stream and
     * dispatches them to the transport. Created by {@link #finishHandshake()}.
     */
    private PlaintextMessageReader readerThread = null;

    /**
     * A string which we can lock on while acquiring new messengers. We don't
     * want to lock the whole connection object.
     */
    private final String acquireMessengerLock = new String("Messenger Acquire Lock");

    /**
     * Cached messenger for sending to {@link #destAddr}
     */
    private Messenger outBoundMessenger = null;

    /**
     * Tracks the state of our TLS connection with a remote peer.
     */
    enum HandshakeState {

        /**
         * Handshake is ready to begin. We will be the client side.
         */
        CLIENTSTART,

        /**
         * Handshake is ready to begin. We will be the server side.
         */
        SERVERSTART,

        /**
         * Handshake is in progress.
         */
        HANDSHAKESTARTED,

        /**
         * Handshake failed to complete.
         */
        HANDSHAKEFAILED,

        /**
         * Handshake completed successfully.
         */
        HANDSHAKEFINISHED,

        /**
         * Connection is closing.
         */
        CONNECTIONCLOSING,

        /**
         * Connection has died.
         */
        CONNECTIONDEAD
    }

    /**
     * Create a new connection.
     *
     * <p/>Selects a trust manager factory (system property override first,
     * then SunJSSE, then IBMJSSE, then the platform default), initializes an
     * SSL context from the transport's PSE key store and credential, and
     * layers an {@link SSLSocket} over a new {@link TlsSocket}. The handshake
     * itself is not performed here; see {@link #finishHandshake()}.
     *
     * @param tp       the TLS transport this connection belongs to.
     * @param destAddr the address of the remote peer.
     * @param client   {@code true} if we are the client side of the handshake.
     * @throws Exception if the trust manager factory, SSL context or SSL
     *                   socket cannot be created.
     */
    TlsConn(TlsTransport tp, EndpointAddress destAddr, boolean client) throws Exception {
        this.transport = tp;
        this.destAddr = destAddr;
        this.client = client;
        this.currentState = client ? HandshakeState.CLIENTSTART : HandshakeState.SERVERSTART;
        this.lastAccessed = TimeUtils.timeNow();

        if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
            LOG.info((client ? "Initiating" : "Accepting") + " new connection for : " + destAddr.getProtocolAddress());
        }

        boolean choseTMF = false;
        javax.net.ssl.TrustManagerFactory tmf = null;

        // An explicitly configured algorithm always wins.
        String overrideTMF = System.getProperty("net.jxta.impl.endpoint.tls.TMFAlgorithm");
        if (null != overrideTMF) {
            tmf = javax.net.ssl.TrustManagerFactory.getInstance(overrideTMF);
            choseTMF = true;
        }

        Set<String> providerNames = new HashSet<String>();
        for (Provider provider : Security.getProviders()) {
            providerNames.add(provider.getName());
        }

        if ((!choseTMF) && providerNames.contains("SunJSSE")) {
            tmf = javax.net.ssl.TrustManagerFactory.getInstance("SunX509", "SunJSSE");
            choseTMF = true;
        }

        if ((!choseTMF) && providerNames.contains("IBMJSSE")) {
            tmf = javax.net.ssl.TrustManagerFactory.getInstance("IbmX509", "IBMJSSE");
            choseTMF = true;
        }

        // XXX 20040830 bondolo Other solutions go here!

        if (!choseTMF) {
            tmf = javax.net.ssl.TrustManagerFactory.getInstance(javax.net.ssl.TrustManagerFactory.getDefaultAlgorithm());
            LOG.warning("Using defeualt Trust Manager Factory algorithm. This may not work as expected.");
        }

        KeyStore trusted = transport.membership.getPSEConfig().getKeyStore();
        tmf.init(trusted);

        javax.net.ssl.TrustManager[] tms = tmf.getTrustManagers();
        javax.net.ssl.KeyManager[] kms = new javax.net.ssl.KeyManager[]{
            new PSECredentialKeyManager(transport.credential, trusted)
        };

        context = SSLContext.getInstance("TLS");
        context.init(kms, tms, null);

        javax.net.ssl.SSLSocketFactory factory = context.getSocketFactory();

        // endpoint interface
        TlsSocket newConnect = new TlsSocket(new JTlsInputStream(this, tp.MIN_IDLE_RECONNECT),
                new JTlsOutputStream(transport, this));

        // open SSL socket; the handshake is performed later by finishHandshake()
        ssls = (SSLSocket) factory.createSocket(newConnect, destAddr.getProtocolAddress(), JTlsDefs.FAKEPORT, true);
        // Only TLSv1 is enabled on this socket.
        ssls.setEnabledProtocols(new String[]{"TLSv1"});
        ssls.setUseClientMode(client);

        if (!client) {
            // The server side requires the client to authenticate.
            ssls.setNeedClientAuth(true);
        }

        // We have to delay initialization of this until we have set the
        // handshake mode.
        tlsSocket = newConnect;
    }

    /**
     * {@inheritDoc}
     *
     * <p/>An implementation which is useful for debugging.
     */
    @Override
    public String toString() {
        return super.toString() + "/" + getHandshakeState() + ":" + (client ? "Client" : "Server") + " for " + destAddr;
    }

    /**
     * Returns the current state of the connection
     *
     * @return the current state of the connection.
     */
    HandshakeState getHandshakeState() {
        return currentState;
    }

    /**
     * Changes the state of the connection. Calls
     * {@link java.lang.Object#notifyAll()} to wake any threads waiting on
     * connection state changes.
     *
     * @param newstate the new connection state.
     * @return the previous state of the connection.
     */
    synchronized HandshakeState setHandshakeState(HandshakeState newstate) {
        HandshakeState oldstate = currentState;

        currentState = newstate;
        notifyAll();
        return oldstate;
    }

    /**
     * Open the connection with the remote peer.
     *
     * <p/>On success the plaintext output stream is set up and the plaintext
     * reader thread is started.
     *
     * @throws java.io.IOException if handshake fails
     */
    void finishHandshake() throws IOException {
        long startTime = 0;

        if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
            startTime = TimeUtils.timeNow();
            LOG.info((client ? "Client:" : "Server:") + " Handshake START");
        }

        setHandshakeState(HandshakeState.HANDSHAKESTARTED);

        // this starts a handshake
        SSLSession newSession = ssls.getSession();

        // A session with the null cipher suite is treated as a failed handshake.
        if ("SSL_NULL_WITH_NULL_NULL".equals(newSession.getCipherSuite())) {
            setHandshakeState(HandshakeState.HANDSHAKEFAILED);
            throw new IOException("Handshake failed");
        }

        setHandshakeState(HandshakeState.HANDSHAKEFINISHED);

        if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
            long hsTime = TimeUtils.toRelativeTimeMillis(TimeUtils.timeNow(), startTime) / TimeUtils.ASECOND;

            LOG.info((client ? "Client:" : "Server:") + "Handshake DONE in " + hsTime + " secs");
        }

        // set up plain text i/o
        // writes to be encrypted
        plaintext_out = new BufferedOutputStream(ssls.getOutputStream(), BOSIZE);

        // Start reader thread
        readerThread = new PlaintextMessageReader(ssls.getInputStream());
    }

    /**
     * Close this connection.
     *
     * <p/>Closes the synthetic socket, the SSL socket and the cached outbound
     * messenger. I/O errors from closing the two sockets are ignored; any
     * other failure is rethrown wrapped in an {@link IOException}. The
     * connection always ends up in {@code finalstate}.
     *
     * @param finalstate state that the connection will be in after close.
     * @throws java.io.IOException if an error occurs
     */
    void close(HandshakeState finalstate) throws IOException {

        synchronized (lastAccessedLock) {
            lastAccessed = Long.MIN_VALUE;
        }

        synchronized (closeLock) {
            closing = true;

            if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                LOG.info("Shutting down " + this);
            }

            setHandshakeState(HandshakeState.CONNECTIONCLOSING);

            try {
                if (null != tlsSocket) {
                    try {
                        tlsSocket.close();
                    } catch (IOException ignored) {
                        // best effort; we are tearing the connection down anyway
                    }
                }

                if (null != ssls) {
                    try {
                        ssls.close();
                    } catch (IOException ignored) {
                        // best effort; we are tearing the connection down anyway
                    }
                    ssls = null;
                }

                if (null != outBoundMessenger) {
                    outBoundMessenger.close();
                    outBoundMessenger = null;
                }
            } catch (Throwable failed) {
                if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                    LOG.log(Level.INFO, "Throwable during close " + this, failed);
                }

                IOException failure = new IOException("Throwable during close()");

                failure.initCause(failed);
                throw failure;
            } finally {
                closeLock.notifyAll();
                closing = false;
                setHandshakeState(finalstate);
            }
        }
    }

    /**
     * Used by the TlsManager and the TlsConn in order to send a message,
     * either a TLS connection establishment, or TLS fragments to the remote TLS.
     *
     * @param msg message to send to the remote TLS peer.
     * @return if true then message was sent, otherwise false.
     * @throws IOException if there was a problem sending the message.
     */
    boolean sendToRemoteTls(Message msg) throws IOException {

        // Work on a local snapshot of the cached messenger: close() may null
        // the field at any time, and dereferencing the field after leaving
        // the lock could otherwise fail with a NullPointerException.
        Messenger messenger;

        synchronized (acquireMessengerLock) {
            messenger = outBoundMessenger;

            if ((null == messenger) || messenger.isClosed()) {

                if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                    LOG.fine("Getting messenger for " + destAddr);
                }

                EndpointAddress realAddr = new EndpointAddress(destAddr, JTlsDefs.ServiceName, null);

                // Get a messenger.
                messenger = transport.endpoint.getMessenger(realAddr);
                outBoundMessenger = messenger;

                if (messenger == null) {
                    if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                        LOG.warning("Could not get messenger for " + realAddr);
                    }
                    return false;
                }
            }
        }

        if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
            LOG.fine("Sending " + msg + " to " + destAddr);
        }

        // Good we have a messenger. Send the message.
        return messenger.sendMessage(msg);
    }

    /**
     * sendMessage is called by the TlsMessenger each time a service or
     * an application sends a new message over a TLS connection.
     * IOException is thrown when something goes wrong.
     *
     * <p/>The message is encrypted by TLS ultimately calling
     * JTlsOutputStream.write(byte[], int, int); with the resulting TLS
     * Record(s).
     *
     * <p/>If sending fails with an {@link IOException} the connection is
     * closed into the {@link HandshakeState#CONNECTIONDEAD} state before the
     * exception is rethrown.
     *
     * @param msg The plaintext message to be sent via this connection.
     * @throws IOException for errors in sending the message.
     */
    void sendMessage(Message msg) throws IOException {

        try {
            WireFormatMessage serialed = WireFormatMessageFactory.toWire(msg, JTlsDefs.MTYPE, null);

            // Flushes issued during serialization are filtered out; we flush
            // explicitly once the whole message has been written.
            serialed.sendToStream(new IgnoreFlushFilterOutputStream(plaintext_out));

            plaintext_out.flush();
        } catch (IOException failed) {
            if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                LOG.log(Level.INFO, "Closing " + this + " due to exception ", failed);
            }

            close(HandshakeState.CONNECTIONDEAD);
            throw failed;
        }
    }

    /**
     * This is our message reader thread. This reads from the plaintext input
     * stream and dispatches messages received to the endpoint.
     */
    private class PlaintextMessageReader implements Runnable {

        InputStream ptin = null;
        Thread workerThread = null;

        /**
         * Creates the reader and immediately starts its daemon worker thread.
         *
         * @param ptin the stream of decrypted plaintext to read messages from.
         */
        public PlaintextMessageReader(InputStream ptin) {
            this.ptin = ptin;

            // start our thread
            workerThread = new Thread(TlsConn.this.transport.myThreadGroup, this,
                    "JXTA TLS Plaintext Reader for " + TlsConn.this.destAddr);
            workerThread.setDaemon(true);
            workerThread.start();

            if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                LOG.info("Started ReadPlaintextMessage thread for " + TlsConn.this.destAddr);
            }
        }

        /**
         * Reads messages until the stream yields {@code null} or an I/O error
         * occurs, handing each message to the transport and refreshing the
         * connection's idle timer.
         */
        public void run() {
            try {
                while (true) {
                    try {
                        Message msg = WireFormatMessageFactory.fromWire(ptin, JTlsDefs.MTYPE, null);

                        if (null == msg) {
                            break;
                        }

                        // dispatch it to TlsTransport for demuxing
                        if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                            LOG.fine("Dispatching " + msg + " to TlsTransport");
                        }

                        TlsConn.this.transport.processReceivedMessage(msg);

                        synchronized (TlsConn.this.lastAccessedLock) {
                            TlsConn.this.lastAccessed = TimeUtils.timeNow(); // update idle timer
                        }
                    } catch (IOException iox) {
                        if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                            LOG.log(Level.WARNING, "I/O error while reading decrypted Message", iox);
                        }
                        break;
                    }
                }
            } catch (Throwable all) {
                if (Logging.SHOW_SEVERE && LOG.isLoggable(Level.SEVERE)) {
                    LOG.log(Level.SEVERE, "Uncaught Throwable in thread :" + Thread.currentThread().getName(), all);
                }
            } finally {
                workerThread = null;
            }

            if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                LOG.info("Finishing ReadPlaintextMessage thread");
            }
        }
    }

    /**
     * A private key manager which selects based on the key and cert chain found
     * in a PSE Credential.
     *
     * <p/>TODO Promote this class to a full featured interface for all of the
     * active PSECredentials. Currently the alias "theone" is used to refer to
     * the private key and certificate chain of the single PSECredential
     * supplied at construction.
     */
    private static class PSECredentialKeyManager implements javax.net.ssl.X509KeyManager {

        /**
         * The special alias which designates the credential's own key and
         * certificate chain.
         */
        private static final String THE_ONE = "theone";

        PSECredential cred;
        KeyStore trusted;

        /**
         * @param useCred the credential whose key and chain back the special alias.
         * @param trusted the key store consulted for every other alias.
         */
        public PSECredentialKeyManager(PSECredential useCred, KeyStore trusted) {
            this.cred = useCred;
            this.trusted = trusted;
        }

        /**
         * {@inheritDoc}
         *
         * <p/>A {@code null} issuers array means that any issuer is
         * acceptable.
         */
        public String chooseClientAlias(String[] keyType, java.security.Principal[] issuers, java.net.Socket socket) {
            if (null == keyType) {
                return null;
            }

            // The X509KeyManager contract allows issuers to be null.
            Collection<java.security.Principal> allIssuers = (null == issuers) ? null : Arrays.asList(issuers);

            for (String aKeyType : keyType) {
                String result = checkTheOne(aKeyType, allIssuers);

                if (null != result) {
                    return result;
                }
            }

            return null;
        }

        /**
         * Checks to see if a peer that trusts the given issuers would trust the
         * special alias THE_ONE, returning it if so, and null otherwise.
         *
         * @param keyType    the type of key a Certificate must use to be considered
         * @param allIssuers the issuers trusted by the other peer, or
         *                   {@code null} if any issuer is acceptable
         * @return "theone" if one of the Certificates in this peer's PSECredential's
         *         Certificate chain matches the given keyType and one of the issuers,
         *         or <code>null</code>
         */
        private String checkTheOne(String keyType, Collection<java.security.Principal> allIssuers) {
            for (X509Certificate certificate : cred.getCertificateChain()) {
                if (!certificate.getPublicKey().getAlgorithm().equals(keyType)) {
                    continue;
                }

                if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                    LOG.fine("CHECKING: " + certificate.getIssuerX500Principal() + " in " + allIssuers);
                }

                if ((null == allIssuers) || allIssuers.contains(certificate.getIssuerX500Principal())) {
                    return THE_ONE;
                }
            }
            return null;
        }

        /**
         * {@inheritDoc}
         */
        public String chooseServerAlias(String keyType, java.security.Principal[] issuers, java.net.Socket socket) {
            String[] available = getServerAliases(keyType, issuers);

            if (null != available) {
                return available[0];
            } else {
                return null;
            }
        }

        /**
         * {@inheritDoc}
         *
         * <p/>Returns {@code null} for a {@code null} alias, for an alias with
         * no chain in the key store, or for a chain containing a certificate
         * which is not an {@link X509Certificate}.
         */
        public X509Certificate[] getCertificateChain(String alias) {
            if (THE_ONE.equals(alias)) {
                return cred.getCertificateChain();
            }

            if (null == alias) {
                return null;
            }

            try {
                // KeyStore returns Certificate[]; casting the array itself to
                // X509Certificate[] can fail, so copy element by element.
                java.security.cert.Certificate[] chain = trusted.getCertificateChain(alias);

                if (null == chain) {
                    return null;
                }

                X509Certificate[] result = new X509Certificate[chain.length];

                for (int i = 0; i < chain.length; i++) {
                    if (!(chain[i] instanceof X509Certificate)) {
                        return null;
                    }
                    result[i] = (X509Certificate) chain[i];
                }

                return result;
            } catch (KeyStoreException ignored) {
                // the alias cannot be resolved from this key store
                return null;
            }
        }

        /**
         * {@inheritDoc}
         *
         * <p/>Considers the certificate entries of the trusted key store whose
         * public key algorithm matches {@code keyType} and, when
         * {@code issuers} is not {@code null}, whose issuer is one of them.
         */
        public String[] getClientAliases(String keyType, java.security.Principal[] issuers) {
            List<String> clientAliases = new ArrayList<String>();

            try {
                Enumeration<String> eachAlias = trusted.aliases();

                Collection<Principal> allIssuers = null;

                if (null != issuers) {
                    allIssuers = Arrays.asList(issuers);
                }

                while (eachAlias.hasMoreElements()) {
                    String anAlias = eachAlias.nextElement();

                    if (!trusted.isCertificateEntry(anAlias)) {
                        continue;
                    }

                    try {
                        java.security.cert.Certificate entry = trusted.getCertificate(anAlias);

                        if (!(entry instanceof X509Certificate)) {
                            // missing (strange... it should have been there...)
                            // or not an X.509 certificate
                            continue;
                        }

                        X509Certificate aCert = (X509Certificate) entry;

                        if (!aCert.getPublicKey().getAlgorithm().equals(keyType)) {
                            continue;
                        }

                        if ((null == allIssuers) || allIssuers.contains(aCert.getIssuerX500Principal())) {
                            clientAliases.add(anAlias);
                        }
                    } catch (KeyStoreException ignored) {
                        // skip entries which cannot be read
                    }
                }
            } catch (KeyStoreException ignored) {
                // key store unusable; report whatever was collected so far
            }

            return clientAliases.toArray(new String[clientAliases.size()]);
        }

        /**
         * {@inheritDoc}
         *
         * <p/>Only the special alias has a private key; every other alias
         * (including {@code null}) yields {@code null}.
         */
        public java.security.PrivateKey getPrivateKey(String alias) {
            if (THE_ONE.equals(alias)) {
                return cred.getPrivateKey();
            } else {
                return null;
            }
        }

        /**
         * {@inheritDoc}
         *
         * <p/>Returns the special alias when the credential's certificate uses
         * {@code keyType} and either {@code issuers} is {@code null} or one of
         * the certificates in the credential's chain was issued by one of
         * them; otherwise returns {@code null}.
         */
        public String[] getServerAliases(String keyType, java.security.Principal[] issuers) {
            if (keyType.equals(cred.getCertificate().getPublicKey().getAlgorithm())) {
                if (null == issuers) {
                    return new String[]{THE_ONE};
                } else {
                    Collection<Principal> allIssuers = Arrays.asList(issuers);

                    if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                        LOG.fine("Looking for : " + cred.getCertificate().getIssuerX500Principal());
                        LOG.fine("Issuers : " + allIssuers);
                        java.security.Principal prin = cred.getCertificate().getIssuerX500Principal();

                        LOG.fine(" Principal Type :" + prin.getClass().getName());
                        for (Principal issuer : allIssuers) {
                            LOG.fine("Issuer Type : " + issuer.getClass().getName());
                            LOG.fine("Issuer value : " + issuer);
                            LOG.fine("tmp.equals(prin) : " + issuer.equals(prin));
                        }
                    }

                    for (X509Certificate aCert : cred.getCertificateChain()) {
                        if (allIssuers.contains(aCert.getIssuerX500Principal())) {
                            return new String[]{THE_ONE};
                        }
                    }
                }
            }
            return null;
        }
    }
}