/* * Copyright (c) 2001-2007 Sun Microsystems, Inc. All rights reserved. * * The Sun Project JXTA(TM) Software License * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. The end-user documentation included with the redistribution, if any, must * include the following acknowledgment: "This product includes software * developed by Sun Microsystems, Inc. for JXTA(TM) technology." * Alternately, this acknowledgment may appear in the software itself, if * and wherever such third-party acknowledgments normally appear. * * 4. The names "Sun", "Sun Microsystems, Inc.", "JXTA" and "Project JXTA" must * not be used to endorse or promote products derived from this software * without prior written permission. For written permission, please contact * Project JXTA at http://www.jxta.org. * * 5. Products derived from this software may not be called "JXTA", nor may * "JXTA" appear in their name, without prior written permission of Sun. * * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL SUN
 * MICROSYSTEMS OR ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * JXTA is a registered trademark of Sun Microsystems, Inc. in the United
 * States and other countries.
 *
 * Please see the license information page at :
 * <http://www.jxta.org/project/www/license.html> for instructions on use of
 * the license in source files.
 *
 * ====================================================================
 *
 * This software consists of voluntary contributions made by many individuals
 * on behalf of Project JXTA. For more information on Project JXTA, please see
 * http://www.jxta.org.
 *
 * This license is based on the BSD license adopted by the Apache Foundation.
 */

package net.jxta.impl.endpoint.tls;

import net.jxta.endpoint.EndpointAddress;
import net.jxta.endpoint.EndpointListener;
import net.jxta.endpoint.Message;
import net.jxta.endpoint.MessageElement;
import net.jxta.impl.endpoint.tls.TlsConn.HandshakeState;
import net.jxta.impl.util.TimeUtils;
import net.jxta.logging.Logging;

import java.io.DataInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Manages the connection pool between peers.
 *
 * <p/>Outgoing connections are created on demand by {@link #getTlsConn}.
 * Incoming connections are created by {@link #processIncomingMessage} when a
 * message carrying TLS record sequence number 1 arrives from a peer without a
 * live connection, provided {@code TlsTransport.ACT_AS_SERVER} is set.
 **/
class TlsManager implements EndpointListener {

    /**
     * Logger for this class (java.util.logging).
     **/
    private final static Logger LOG = Logger.getLogger(TlsManager.class.getName());

    /**
     * Transport we are working for.
     **/
    private final TlsTransport transport;

    /**
     * Known connections.
     *
     * <ul>
     * <li>keys are {@link String}s containing the remote peer's
     * {@link EndpointAddress#getProtocolAddress() protocol address}</li>
     * <li>values are {@link TlsConn}s</li>
     * </ul>
     *
     * <p/>Every access to this map in this class is synchronized on the map itself.
     **/
    private final Map<String, TlsConn> connections = new HashMap<String, TlsConn>();

    /**
     * The last time at which we printed a warning about discarding messages
     * due to no authentication. Used to rate-limit that warning to about one
     * per minute.
     **/
    private long lastNonAuthenticatedWarning = 0;

    /**
     * Standard Constructor for TLS Manager
     *
     * @param tp the transport this manager works for.
     **/
    TlsManager(TlsTransport tp) {
        this.transport = tp;
    }

    /**
     * Close this manager. This involves closing all registered connections
     * (marking them {@code CONNECTIONDEAD}) and emptying the pool.
     **/
    void close() {
        if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
            LOG.info("Shutting down all connections");
        }

        synchronized (connections) {
            Iterator<TlsConn> eachConnection = connections.values().iterator();

            while (eachConnection.hasNext()) {
                TlsConn aConnection = eachConnection.next();

                try {
                    aConnection.close(HandshakeState.CONNECTIONDEAD);
                } catch (IOException ignored) {
                    // Best-effort shutdown: log and keep closing the rest.
                    if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                        LOG.info("Non-fatal problem shutting down connection to " + aConnection);
                    }
                }

                eachConnection.remove();
            }
        }
    }

    /**
     * Removes the pool entry for {@code paddr}, but only if it still refers to
     * {@code conn}. Between registering a connection and discovering that its
     * handshake failed, another thread may have replaced it with a fresh
     * connection for the same peer. That newer entry must not be discarded.
     *
     * @param paddr the protocol address under which {@code conn} was registered.
     * @param conn  the connection whose registration should be removed.
     **/
    private void removeConnection(String paddr, TlsConn conn) {
        synchronized (connections) {
            if (connections.get(paddr) == conn) {
                if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                    LOG.fine("Removing connection for: " + paddr);
                }
                connections.remove(paddr);
            }
        }
    }

    /**
     * Returns or creates a TLS Connection to the specified peer. If an
     * existing live connection exists, it will be returned. If this call
     * creates the connection, it also drives the client handshake.
     *
     * <p/>The call blocks until the connection's handshake has finished (the
     * connection is returned) or has failed, died, or begun closing (null is
     * returned).
     *
     * @param dstAddr the EndpointAddress of the remote peer.
     * @return A TLS Connection or null if the connection could not be opened.
     **/
    TlsConn getTlsConn(EndpointAddress dstAddr) {
        if (null == transport.credential) {
            if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                LOG.warning("Not authenticated. Cannot open connections.");
            }
            return null;
        }

        boolean startHandshake = false;

        // See if we have an existing connection and, if so, reuse it unless
        // it is dead or its handshake failed.
        String paddr = dstAddr.getProtocolAddress();
        TlsConn conn;

        synchronized (connections) {
            conn = connections.get(paddr);

            // remove it if it is dead
            if (null != conn) {
                if ((HandshakeState.CONNECTIONDEAD == conn.getHandshakeState())
                        || (HandshakeState.HANDSHAKEFAILED == conn.getHandshakeState())) {
                    removeConnection(paddr, conn);
                    conn = null;
                }
            }

            // create the connection info entry as needed
            if (null == conn) {
                try {
                    conn = new TlsConn(transport, dstAddr, true); // true means client
                } catch (Exception failed) {
                    if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                        LOG.log(Level.WARNING, "Failed making connection to " + paddr, failed);
                    }
                    return null;
                }

                if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                    LOG.fine("Adding connection for: " + paddr);
                }

                connections.put(paddr, conn);
                startHandshake = true;
            }
        }

        // If we were the one to create the connection then start the
        // handshake here. This is done outside of the synchronized block so
        // that other threads can enter the state machine.
        if (startHandshake) {
            try {
                // We are originating the connection. finishHandshake() returns
                // when the handshake is complete, or throws if a TLS internal
                // error occurs.
                if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                    LOG.info("Start of client handshake for " + paddr);
                }

                conn.finishHandshake();
            } catch (Throwable e) {
                if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                    LOG.log(Level.WARNING, "Failed making connection to " + paddr, e);
                }

                // Only drop our own entry; a newer connection may have replaced it.
                removeConnection(paddr, conn);

                try {
                    conn.close(HandshakeState.HANDSHAKEFAILED);
                } catch (IOException ignored) {
                    // The connection is being abandoned anyway.
                }

                return null;
            }
        }

        // Wait for the connection to reach a terminal handshake state.
        do {
            if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                LOG.fine("getting " + conn);
            }

            synchronized (conn) {
                HandshakeState currentState = conn.getHandshakeState();

                if ((HandshakeState.SERVERSTART == currentState) || (HandshakeState.CLIENTSTART == currentState)) {
                    // wait for the handshake to get going on another thread.
                    if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                        LOG.fine("Sleeping until handshake starts for " + paddr);
                    }

                    try {
                        conn.wait(TimeUtils.ASECOND);
                    } catch (InterruptedException woken) {
                        // NOTE(review): the interrupt is intentionally not
                        // re-asserted; doing so would make the next wait()
                        // throw immediately and spin this loop.
                        Thread.interrupted();
                    }
                } else if (HandshakeState.HANDSHAKESTARTED == currentState) {
                    if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                        LOG.fine("Handshake in progress for " + paddr);
                    }

                    try {
                        // Wait briefly (200 ms) for the state to change, then re-check.
                        conn.wait(200);
                    } catch (InterruptedException woken) {
                        // See note above: interrupt deliberately swallowed.
                        Thread.interrupted();
                    }
                } else if (HandshakeState.HANDSHAKEFINISHED == currentState) {
                    if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                        LOG.info("Returning active connection to " + paddr);
                    }

                    conn.lastAccessed = TimeUtils.timeNow(); // update idle timer

                    return conn;
                } else if (HandshakeState.HANDSHAKEFAILED == currentState) {
                    if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                        LOG.warning("Handshake failed. " + paddr + " unreachable");
                    }
                    return null;
                } else if (HandshakeState.CONNECTIONDEAD == currentState) {
                    if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                        LOG.warning("Connection dead for " + paddr);
                    }
                    return null;
                } else if (HandshakeState.CONNECTIONCLOSING == currentState) {
                    if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                        LOG.warning("Connection closing for " + paddr);
                    }
                    return null;
                } else {
                    if (Logging.SHOW_SEVERE && LOG.isLoggable(Level.SEVERE)) {
                        LOG.severe("Unhandled Handshake state: " + currentState);
                    }
                }
            }
        } while (true);
    }

    /**
     * Handle an incoming message from the endpoint. This method demultiplexes
     * incoming messages to the connection objects by their source address.
     *
     * <p/>Several types of messages may be received for a connection:
     *
     * <ul>
     * <li>TLS Elements</li>
     * <li>Element Acknowledgements</li>
     * </ul>
     *
     * <p/>All messages are discarded while the transport has no credential.
     * A message with sequence number 1 from a peer without a live connection
     * starts a new server-side connection, if {@code TlsTransport.ACT_AS_SERVER}
     * is set. Messages for unknown peers that are not sequence number 1 are
     * discarded.
     *
     * @param msg     is the incoming message
     * @param srcAddr is the address of the source of the message
     * @param dstAddr is the address of the destination of the message
     **/
    public void processIncomingMessage(Message msg, EndpointAddress srcAddr, EndpointAddress dstAddr) {
        if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
            LOG.fine("Starts for " + msg);
        }

        if (null == transport.credential) {
            // ignore ALL messages until we are authenticated; warn at most once a minute.
            if (TimeUtils.toRelativeTimeMillis(TimeUtils.timeNow(), lastNonAuthenticatedWarning) > TimeUtils.AMINUTE) {
                if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                    LOG.warning("NOT AUTHENTICATED--Discarding all incoming messages");
                }

                lastNonAuthenticatedWarning = TimeUtils.timeNow();
            }
            return;
        }

        // Determine if it's a retry. The marker element is stripped from the message.
        MessageElement retryElement = msg.getMessageElement(JTlsDefs.TLSNameSpace, JTlsDefs.RETR);
        boolean retrans = (null != retryElement);

        if (retrans) {
            msg.removeMessageElement(retryElement);
        }

        int seqN = getMsgSequenceNumber(msg);

        // Extract unique part of source address
        String paddr = srcAddr.getProtocolAddress();

        TlsConn conn;
        boolean serverStart = false;

        synchronized (connections) {
            // Will be in our hash table unless this is for a first time
            // incoming connection request
            conn = connections.get(paddr);

            if (null != conn) {
                // check if the connection has idled out and remote is asking for a restart.
                if (TlsTransport.ACT_AS_SERVER && (1 == seqN)) {
                    synchronized (conn) {
                        long idle = TimeUtils.toRelativeTimeMillis(TimeUtils.timeNow(), conn.lastAccessed);

                        if (idle > transport.MIN_IDLE_RECONNECT) {
                            if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                                LOG.warning("Restarting : " + conn + " which has been idle for " + idle + " millis");
                            }

                            try {
                                conn.close(HandshakeState.CONNECTIONDEAD);
                            } catch (IOException ignored) {
                                // It is being replaced below regardless.
                            }
                        }
                    }
                }

                // remove it if it is dead
                if ((HandshakeState.CONNECTIONDEAD == conn.getHandshakeState())
                        || (HandshakeState.HANDSHAKEFAILED == conn.getHandshakeState())) {
                    removeConnection(paddr, conn);
                    conn = null;
                }
            }

            // we don't have a connection to this destination, make a new connection if seqn#1
            if (null == conn) {
                if (TlsTransport.ACT_AS_SERVER && (1 == seqN)) {
                    try {
                        conn = new TlsConn(transport, srcAddr, false); // false means Server
                    } catch (Exception failed) {
                        if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                            LOG.log(Level.WARNING, "Failed making connection for " + paddr, failed);
                        }
                        return;
                    }

                    if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                        LOG.fine("Adding connection for: " + paddr);
                    }

                    connections.put(paddr, conn);
                    serverStart = true;
                } else {
                    // Garbage from an old connection. discard it
                    if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                        LOG.warning(msg + " is not start of handshake (seqn#" + seqN + ") for " + paddr);
                    }

                    msg.clear();
                    return;
                }
            }
        }

        // if this is a new connection, get it started.
        if (serverStart) {
            try {
                if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                    LOG.info("Start of SERVER handshake for " + paddr);
                }

                // Queue message up for TlsInputStream on that connection
                conn.tlsSocket.input.queueIncomingMessage(msg);

                // Start the TLS Server and complete the handshake
                conn.finishHandshake(); // open the TLS connection

                conn.lastAccessed = TimeUtils.timeNow();

                if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                    LOG.info("Handshake complete for SERVER TLS for: " + paddr);
                }

                return;
            } catch (Throwable e) {
                // Handshake failure or IOException
                if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                    LOG.log(Level.WARNING, "TLS Handshake failure for connection: " + paddr, e);
                }

                // Only drop our own entry; a newer connection may have replaced it.
                removeConnection(paddr, conn);

                try {
                    conn.close(HandshakeState.HANDSHAKEFAILED);
                } catch (IOException ignored) {
                    // The connection is being abandoned anyway.
                }

                return;
            }
        }

        // handle an ongoing connection.
        do {
            HandshakeState currentState;

            synchronized (conn) {
                if (retrans) {
                    conn.retrans++;

                    if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                        LOG.fine("retrans received, " + conn.retrans + " total.");
                    }

                    // Count each retransmission only once even if we loop.
                    retrans = false;
                }

                if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                    LOG.fine("Process incoming message for " + conn);
                }

                currentState = conn.getHandshakeState();

                if ((HandshakeState.HANDSHAKESTARTED == currentState)
                        || (HandshakeState.HANDSHAKEFINISHED == currentState)
                        || (HandshakeState.CONNECTIONCLOSING == currentState)) {
                    // we will process the message once we get out of sync.
                } else if (HandshakeState.CONNECTIONDEAD == currentState) {
                    // Connection is dead; the message cannot be delivered.
                    if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                        LOG.info("Connection failed, discarding msg with seqn#" + seqN + " for " + paddr);
                    }
                    return;
                } else if ((HandshakeState.SERVERSTART == currentState) || (HandshakeState.CLIENTSTART == currentState)) {
                    // wait for the handshake to get going on another thread.
                    if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                        LOG.fine("Sleeping msg with seqn#" + seqN + " until handshake starts for " + paddr);
                    }

                    try {
                        conn.wait(TimeUtils.AMINUTE);
                    } catch (InterruptedException woken) {
                        // NOTE(review): interrupt deliberately swallowed; the
                        // loop re-checks the state after waking.
                        Thread.interrupted();
                    }
                    continue;
                } else if (HandshakeState.HANDSHAKEFAILED == currentState) {
                    // Handshake failed; the message cannot be delivered.
                    if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) {
                        LOG.info("Handshake failed, discarding msg with seqn#" + seqN + " for " + paddr);
                    }
                    return;
                } else {
                    if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                        LOG.warning("Unexpected state : " + currentState);
                    }
                }
            }

            // Process any message outside of the sync on the connection.
            if ((HandshakeState.HANDSHAKESTARTED == currentState)
                    || (HandshakeState.HANDSHAKEFINISHED == currentState)
                    || (HandshakeState.CONNECTIONCLOSING == currentState)) {
                // process any ACK messages.
                Iterator<MessageElement> eachACK = msg.getMessageElements(JTlsDefs.TLSNameSpace, JTlsDefs.ACKS);

                while (eachACK.hasNext()) {
                    MessageElement elt = eachACK.next();

                    eachACK.remove();

                    // The element is read as a run of 4-byte ints. The first is
                    // passed to ackReceived() as the ACK sequence number, and the
                    // remainder as a sorted array of selective ACKs.
                    int sackCount = ((int) elt.getByteLength() / 4) - 1;
                    DataInputStream dis = null;

                    try {
                        dis = new DataInputStream(elt.getStream());

                        int seqack = dis.readInt();

                        int[] sacs = new int[sackCount];

                        for (int eachSac = 0; eachSac < sackCount; eachSac++) {
                            sacs[eachSac] = dis.readInt();
                        }

                        Arrays.sort(sacs);

                        // take care of the ACK here;
                        conn.tlsSocket.output.ackReceived(seqack, sacs);
                    } catch (IOException failed) {
                        if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                            LOG.log(Level.WARNING, "Failure processing ACK", failed);
                        }
                    } finally {
                        if (null != dis) {
                            try {
                                dis.close();
                            } catch (IOException ignored) {
                                // Nothing useful to do; the ACK was already handled or reported.
                            }
                        }
                    }
                }

                // Sequence number 0 means the message carried no TLS records (ACKs only).
                if (0 == seqN) {
                    return;
                }

                if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) {
                    LOG.fine("Queue " + msg + " seqn#" + seqN + " for " + conn);
                }

                // Queue message up for TlsInputStream on that connection
                TlsSocket bound = conn.tlsSocket;

                if (null != bound) {
                    bound.input.queueIncomingMessage(msg);
                }

                return;
            }
        } while (true);
    }

    /**
     * Returns the sequence number of the first TLS record element in the
     * message. The sequence number is the element's name parsed as an integer.
     * Elements of the TLS block type whose names are not integers are logged
     * and removed from the message.
     *
     * @param msg Input message
     * @return int sequence number or 0 (zero) if no tls records in message.
     **/
    private static int getMsgSequenceNumber(Message msg) {
        int seqN = 0;

        Iterator<MessageElement> eachElement = msg.getMessageElements(JTlsDefs.TLSNameSpace, JTlsDefs.BLOCKS);

        while (eachElement.hasNext()) {
            MessageElement elt = eachElement.next();

            try {
                seqN = Integer.parseInt(elt.getElementName());
            } catch (NumberFormatException e) {
                // This element was not a TLS element. Get the next one
                if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) {
                    LOG.warning("Bad tls record name=" + elt.getElementName());
                }

                eachElement.remove();
                continue;
            }

            break;
        }

        return seqN;
    }
}