package org.limewire.net; import java.io.IOException; import java.net.Socket; import java.nio.ByteBuffer; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.limewire.nio.AbstractNBSocket; import org.limewire.nio.channel.AbstractChannelInterestReader; import org.limewire.nio.ssl.SSLUtils; import org.limewire.util.BufferUtils; import org.limewire.util.StringUtils; /** * A ConnectionDispatcher that reads asynchronously from the socket. */ public class AsyncConnectionDispatcher extends AbstractChannelInterestReader { private static final Log LOG = LogFactory.getLog(AsyncConnectionDispatcher.class); private final ConnectionDispatcher dispatcher; private final Socket socket; private final String allowedWord; private final boolean allowTLS; private boolean finished = false; public AsyncConnectionDispatcher(ConnectionDispatcher dispatcher, Socket socket, String allowedWord, boolean allowTLS) { // + 1 for whitespace super(dispatcher.getMaximumWordSize() + 1); if (socket == null) { throw new IllegalArgumentException(); } this.dispatcher = dispatcher; this.socket = socket; this.allowedWord = allowedWord; this.allowTLS = allowTLS; } public void handleRead() throws IOException { // If we already finished our reading, turn read interest off // and exit early. if(finished) { source.interestRead(false); return; } // Fill up our buffer as much we can. int read = 0; while(buffer.hasRemaining() && (read = source.read(buffer)) > 0); // See if we have a full word. for(int i = 0; i < buffer.position(); i++) { if(buffer.get(i) == ' ') { String word = StringUtils.getASCIIString(buffer.array(), 0, i); if(dispatcher.isValidProtocolWord(word)) { if(allowedWord != null && !allowedWord.equals(word)) { if(LOG.isDebugEnabled()) LOG.debug("Legal but wrong word: " + word); throw new IOException("wrong word!"); } if(LOG.isDebugEnabled()) LOG.debug("Dispatching word: " + word); buffer.limit(buffer.position()).position(i+1); source.interestRead(false); dispatcher.dispatch(word, socket, true); } else { startTLS(); } finished = true; return; } } // If there's no room to read more or there's nothing left to read, // we aren't going to read our word. Attempt to switch to TLS, or // close if we EOF'd early. if(!buffer.hasRemaining()) { startTLS(); finished = true; return; } else if(read == -1) { close(); return; } } /** * Attempts to start TLS encoding on the socket. * If any data was buffered but not used, the data will be read as part * of the TLS handshake. If the socket is not capable of switching to TLS, * the socket is closed. * * @throws IOException if there was an error starting TLS */ private void startTLS() throws IOException { if (allowTLS && !SSLUtils.isTLSEnabled(socket) && SSLUtils.isStartTLSCapable(socket)) { LOG.debug("Attempting to start TLS"); buffer.flip(); AbstractNBSocket tlsSocket = SSLUtils.startTLS(socket, buffer); tlsSocket.setReadObserver(new AsyncConnectionDispatcher(dispatcher, tlsSocket, allowedWord, allowTLS)); } else { close(); } } @Override public int read(ByteBuffer dst) { return BufferUtils.transfer(buffer, dst, false); } @Override public long read(ByteBuffer [] dst) { return BufferUtils.transfer(buffer, dst, 0, dst.length, false); } @Override public long read(ByteBuffer [] dst, int offset, int length) { return BufferUtils.transfer(buffer, dst, offset, length, false); } }