/* * Copyright (c) 2001-2007 Sun Microsystems, Inc. All rights reserved. * * The Sun Project JXTA(TM) Software License * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. The end-user documentation included with the redistribution, if any, must * include the following acknowledgment: "This product includes software * developed by Sun Microsystems, Inc. for JXTA(TM) technology." * Alternately, this acknowledgment may appear in the software itself, if * and wherever such third-party acknowledgments normally appear. * * 4. The names "Sun", "Sun Microsystems, Inc.", "JXTA" and "Project JXTA" must * not be used to endorse or promote products derived from this software * without prior written permission. For written permission, please contact * Project JXTA at http://www.jxta.org. * * 5. Products derived from this software may not be called "JXTA", nor may * "JXTA" appear in their name, without prior written permission of Sun. * * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES, * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SUN * MICROSYSTEMS OR ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * JXTA is a registered trademark of Sun Microsystems, Inc. in the United * States and other countries. * * Please see the license information page at : * <http://www.jxta.org/project/www/license.html> for instructions on use of * the license in source files. * * ==================================================================== * * This software consists of voluntary contributions made by many individuals * on behalf of Project JXTA. For more information on Project JXTA, please see * http://www.jxta.org. * * This license is based on the BSD license adopted by the Apache Foundation. */ package net.jxta.impl.endpoint.tls; import net.jxta.endpoint.ByteArrayMessageElement; import net.jxta.endpoint.Message; import net.jxta.endpoint.MessageElement; import net.jxta.impl.util.TimeUtils; import net.jxta.logging.Logging; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InterruptedIOException; import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Vector; import java.util.logging.Level; import java.util.logging.Logger; /** * Acts as the input for TLS. Accepts ciphertext which arrives in messages * and orders it before passing it to TLS for decryption. * * TLS will do its raw reads off of this InputStream * Here, we will have queued up the payload of TLS message * elements to be passed to TLS code as TLS Records. * */ class JTlsInputStream extends InputStream { private static final Logger LOG = Logger.getLogger(JTlsInputStream.class.getName()); private static final boolean DEBUGIO = false; static private int MAXQUEUESIZE = 25; /** * Connection we are working for. */ private TlsConn conn; private volatile boolean closed = false; private boolean closing = false; private long timeout = 2 * TimeUtils.AMINUTE; private JTlsRecord jtrec = null; private volatile int sequenceNumber = 0; private final Vector<IQElt> inputQueue = new Vector<IQElt>(MAXQUEUESIZE); // For incoming messages. /** * Input TLS record Object **/ private static class JTlsRecord { // This dummy message elt public InputStream tlsRecord; // TLS Record public long nextByte; // next inbuff byte public long size; // size of TLS Record public JTlsRecord() { tlsRecord = null; // allocated by caller nextByte = 0; // We read here (set by caller) size = 0; // TLS Record size(set by caller) } // reset the jxta tls record element public void resetRecord() { if (null != tlsRecord) { try { tlsRecord.close(); } catch (IOException ignored) {// ignored } } tlsRecord = null; size = nextByte = 0; } } // An input queue element which breaks out a // received message in enqueueMessage(). private static class IQElt { int seqnum; MessageElement elt; boolean ackd; } public JTlsInputStream(TlsConn conn, long timeout) { this.timeout = timeout; this.conn = conn; jtrec = new JTlsRecord(); // 1 <= seq# <= maxint, monotonically increasing // Incremented before compare. sequenceNumber = 0; } /** * {@inheritDoc} **/ @Override public void close() throws IOException { super.close(); closed = true; synchronized (inputQueue) { inputQueue.clear(); inputQueue.notifyAll(); } } /** * prepare this input stream to being closed. It will still * deliver the packets that have been received, but nothing * more. This is meant to be called in response to the other side * having initiated closure. We assume that when the other side does it * it means that it is satified with what we have acknoleged so far. */ public void setClosing() throws IOException { synchronized (inputQueue) { closing = true; inputQueue.notifyAll(); } } // Here we read the TLS Record data from the incoming JXTA message. // (We will really have a full jxta message available.) // // TLS Record input only calls the following methods. // They are called from SSLRecord.decode(SSLConn, Inputstream); // /** * {@inheritDoc} */ @Override public int read() throws IOException { if (closed) { return -1; } byte[] a = new byte[1]; while (true) { int len = local_read(a, 0, 1); if (len < 0) { break; } if (len > 0) { if (DEBUGIO && Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("Read() : " + (a[0] & 255)); } return (a[0] & 0xFF); // The byte } } // If we've reached EOF, there's nothing to do but close(). close(); return -1; } /** * {@inheritDoc} */ @Override public int read(byte[] a, int offset, int length) throws IOException { if (closed) { return -1; } if (0 == length) { return 0; } int i = local_read(a, offset, length); if (DEBUGIO && Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("Read(byte[], int, " + length + "), bytes read = " + i); } // If we've reached EOF; there's nothing to do but close(). if (i == -1) { close(); } return i; } // protected accessor for sequence number int getSequenceNumber() { return sequenceNumber; } // Our input queue max size int getMaxIQSize() { return MAXQUEUESIZE; } /** * Send a sequential ACK and selective ACKs for all of the queued messages. * * @param seqnAck the sequence number being sequential ACKed **/ private void sendACK(int seqnAck) { List<Integer> selectedAckList = new ArrayList<Integer>(); synchronized (inputQueue) { Iterator<IQElt> eachInQueue = inputQueue.iterator(); while (eachInQueue.hasNext() && (selectedAckList.size() < MAXQUEUESIZE)) { IQElt anIQElt = eachInQueue.next(); if (anIQElt.seqnum > seqnAck) { selectedAckList.add(new Integer(anIQElt.seqnum)); } } } // PERMIT DUPLICATE ACKS. Just a list and one small message. sendACK(seqnAck, selectedAckList); } /** * Build an ACK message. The message provides a sequential ACK count and * an optional list of selective ACKs. * * @param seqnAck the sequence number being sequential ACKed * @param sackList a list of selective ACKs. Must be sorted in increasing * order. */ private void sendACK(int seqnAck, List<Integer> sackList) { ByteArrayOutputStream bos = new ByteArrayOutputStream((1 + sackList.size()) * 4); DataOutputStream dos = new DataOutputStream(bos); try { dos.writeInt(seqnAck); Iterator<Integer> eachSACK = sackList.iterator(); while (eachSACK.hasNext()) { int aSack = (eachSACK.next()).intValue(); dos.writeInt(aSack); } dos.close(); bos.close(); Message ACKMsg = new Message(); MessageElement elt = new ByteArrayMessageElement(JTlsDefs.ACKKEY, JTlsDefs.ACKS, bos.toByteArray(), null); ACKMsg.addMessageElement(JTlsDefs.TLSNameSpace, elt); conn.sendToRemoteTls(ACKMsg); if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("SENT ACK, seqn#" + seqnAck + " and " + sackList.size() + " SACKs "); } } catch (IOException e) { if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) { LOG.log(Level.INFO, "sendACK caught IOException:", e); } } } /** * queue messages by sequence number. */ public void queueIncomingMessage(Message msg) { if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("Queue Incoming Message begins for " + msg); } long startEnqueue = TimeUtils.timeNow(); Message.ElementIterator e = msg.getMessageElements(JTlsDefs.TLSNameSpace, JTlsDefs.BLOCKS); // OK look for jxta message while (!closed && !closing && e.hasNext()) { MessageElement elt = e.next(); e.remove(); int msgSeqn = 0; try { msgSeqn = Integer.parseInt(elt.getElementName()); } catch (NumberFormatException n) { if (Logging.SHOW_WARNING && LOG.isLoggable(Level.WARNING)) { LOG.warning("Discarding element (" + elt.getElementName() + ") Not one of ours."); } continue; } IQElt newElt = new IQElt(); newElt.seqnum = msgSeqn; newElt.elt = elt; newElt.ackd = false; // OK we must inqueue: // Wait until someone dequeues if we are at the size limit // see if this is a duplicate if (newElt.seqnum <= sequenceNumber) { if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("RCVD OLD MESSAGE : Discard seqn#" + newElt.seqnum + " now at seqn#" + sequenceNumber); } break; } synchronized (inputQueue) { // dbl check with the lock held. if (closing || closed) { return; } // Insert this message into the input queue. // 1. Do not add duplicate messages // 2. Store in increasing sequence nos. int insertIndex = inputQueue.size(); boolean duplicate = false; for (int j = 0; j < inputQueue.size(); j++) { IQElt iq = inputQueue.elementAt(j); if (newElt.seqnum < iq.seqnum) { insertIndex = j; break; } else if (newElt.seqnum == iq.seqnum) { duplicate = true; break; } } if (duplicate) { if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("RCVD OLD MESSAGE : Discard duplicate msg, seqn#" + newElt.seqnum); } newElt = null; break; } inputQueue.add(insertIndex, newElt); if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("Enqueued msg with seqn#" + newElt.seqnum + " at index " + insertIndex); } inputQueue.notifyAll(); newElt = null; } } if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { long waited = TimeUtils.toRelativeTimeMillis(TimeUtils.timeNow(), startEnqueue); LOG.fine("Queue Incoming Message for " + msg + " completed in " + waited + " msec."); } } /** * Dequeue the message with the desired sequence number waiting as needed * until the message is available. * * @param desiredSeqn the sequence number to be dequeued. * @return the Message Element with the desired sequence number or null if * the queue has been closed. **/ private MessageElement dequeueMessage(int desiredSeqn) throws IOException { IQElt iQ = null; // Wait for incoming message here long startDequeue = TimeUtils.timeNow(); long whenToTimeout = startDequeue + timeout; int wct = 0; long nextRetransRequest = TimeUtils.toAbsoluteTimeMillis(TimeUtils.ASECOND); synchronized (inputQueue) { while (!closed) { if (inputQueue.size() == 0) { if (closing) { return null; } try { wct++; inputQueue.wait(TimeUtils.ASECOND); if (whenToTimeout < TimeUtils.timeNow()) { throw new SocketTimeoutException("Read timeout reached"); } } catch (InterruptedException e) { Thread.interrupted(); // just continue } // we reset the retrans request timer since we don't want to // immediately request retry after a long wait for out of // order messages. nextRetransRequest = TimeUtils.toAbsoluteTimeMillis(TimeUtils.ASECOND); continue; } iQ = inputQueue.elementAt(0); // FIFO if (iQ.seqnum < desiredSeqn) { // Ooops a DUPE slipped in the head of the queue undetected // (seqnum consistency issue). // Just drop it. inputQueue.remove(0); // if such is the case then notify the other end so that // the message does not remain in the retry queue eventually // triggering a broken pipe exception sendACK(iQ.seqnum); continue; } else if (iQ.seqnum != desiredSeqn) { if (TimeUtils.toRelativeTimeMillis(nextRetransRequest) < 0) { if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("Trigger retransmission. Wanted seqn#" + desiredSeqn + " found seqn#" + iQ.seqnum); } sendACK(desiredSeqn - 1); nextRetransRequest = TimeUtils.toAbsoluteTimeMillis(TimeUtils.ASECOND); } try { wct++; inputQueue.wait(TimeUtils.ASECOND); if (whenToTimeout < TimeUtils.timeNow()) { throw new SocketTimeoutException("Read timeout reached"); } } catch (InterruptedException e) { throw new InterruptedIOException("IO interrupted "); } continue; } inputQueue.remove(0); break; } } nextRetransRequest = 0; sendACK(desiredSeqn); // if we are closed then we return null if (null == iQ) { return null; } if (Logging.SHOW_INFO && LOG.isLoggable(Level.INFO)) { long waited = TimeUtils.toRelativeTimeMillis(TimeUtils.timeNow(), startDequeue); LOG.info("DEQUEUED seqn#" + iQ.seqnum + " in " + waited + " msec on input queue"); if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { if (wct > 0) { LOG.fine("DEQUEUE waited " + wct + " times on input queue"); } } } return iQ.elt; } /** * */ private int local_read(byte[] a, int offset, int length) throws IOException { synchronized (jtrec) { if ((jtrec.size == 0) || (jtrec.nextByte == jtrec.size)) { // reset the record jtrec.resetRecord(); // GC as necessary(tlsRecord byte[]) if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("local_read: getting next data block at seqn#" + (sequenceNumber + 1)); } MessageElement elt = null; try { elt = dequeueMessage(sequenceNumber + 1); } catch (SocketTimeoutException ste) { // timed out with no data // SSLSocket expects a 0 data in this case return 0; } if (null == elt) { return -1; } sequenceNumber += 1; // next msg sequence number // Get the length of the TLS Record jtrec.size = elt.getByteLength(); jtrec.tlsRecord = elt.getStream(); if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("local_read: new seqn#" + sequenceNumber + ", bytes = " + jtrec.size); } } // return the requested TLS Record data // These calls should NEVER ask for more data than is in the // received TLS Record. long left = jtrec.size - jtrec.nextByte; int copyLen = (int) Math.min(length, left); int copied = 0; do { int res = jtrec.tlsRecord.read(a, offset + copied, copyLen - copied); if (res < 0) { break; } copied += res; } while (copied < copyLen); jtrec.nextByte += copied; if (DEBUGIO) { if (Logging.SHOW_FINE && LOG.isLoggable(Level.FINE)) { LOG.fine("local_read: Requested " + length + ", Read " + copied + " bytes"); } } return copied; } } }