package org.limewire.net;
import java.net.Socket;
import java.net.SocketException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.limewire.concurrent.ThreadExecutor;
import org.limewire.io.IOUtils;
import org.limewire.io.NetworkInstanceUtils;
import org.limewire.io.NetworkUtils;
import org.limewire.util.StringUtils;
import com.google.inject.Inject;
/**
 * Default {@link ConnectionDispatcher}: hands an accepted socket to the
 * {@link ConnectionAcceptor} registered for the first protocol word that
 * was read from that socket.
 * <p>
 * Registrations are kept in a synchronized map. The cached length of the
 * longest registered word is only read or written while holding that map's
 * monitor.
 */
public class ConnectionDispatcherImpl implements ConnectionDispatcher {

    private static final Log LOG = LogFactory.getLog(ConnectionDispatcherImpl.class);

    /**
     * Protocol word -> wrapper around the acceptor that handles it.
     * Compound operations synchronize on this map itself.
     */
    private final Map<String, Delegator> protocols =
        Collections.synchronizedMap(new HashMap<String, Delegator>());

    /**
     * Length of the longest registered protocol word.
     * LOCKING: protocols.
     */
    private int longestWordSize = 0;

    private final NetworkInstanceUtils networkInstanceUtils;

    @Inject
    public ConnectionDispatcherImpl(NetworkInstanceUtils networkInstanceUtils) {
        this.networkInstanceUtils = networkInstanceUtils;
    }

    /**
     * Returns the length of the longest protocol word currently registered,
     * or 0 when nothing is registered.
     */
    public int getMaximumWordSize() {
        synchronized (protocols) {
            return longestWordSize;
        }
    }

    /** Returns true iff every given word consists solely of ASCII characters. */
    private boolean areAscii(String... words) {
        boolean allAscii = true;
        for (int i = 0; allAscii && i < words.length; i++) {
            allAscii = StringUtils.isAsciiOnly(words[i]);
        }
        return allAscii;
    }

    /**
     * Registers {@code acceptor} under each of {@code words}. Any acceptor
     * previously registered under the same word is replaced.
     *
     * @param acceptor  handler for connections that start with one of the words
     * @param localOnly if true, only connections from localhost are handed over
     * @param words     protocol words; ASCII-ness is checked by an assertion only
     */
    public void addConnectionAcceptor(ConnectionAcceptor acceptor,
                                      boolean localOnly,
                                      String... words) {
        assert areAscii(words) : "not all ascii: " + Arrays.asList(words);
        // The acceptor's blocking flag is sampled once, at registration time.
        Delegator delegator = new Delegator(acceptor, localOnly, acceptor.isBlocking());
        synchronized (protocols) {
            for (String word : words) {
                longestWordSize = Math.max(longestWordSize, word.length());
                protocols.put(word, delegator);
            }
        }
    }

    /**
     * Unregisters the given protocol words, whichever acceptor they map to,
     * and recomputes the longest registered word length.
     */
    public void removeConnectionAcceptor(String... words) {
        synchronized (protocols) {
            protocols.keySet().removeAll(Arrays.asList(words));
            recomputeLongestWordSize();
        }
    }

    /** Rescans all registered words. Caller must hold the {@code protocols} monitor. */
    private void recomputeLongestWordSize() {
        int longest = 0;
        // Iterating a synchronizedMap view requires holding the map's monitor.
        for (String registered : protocols.keySet()) {
            longest = Math.max(longest, registered.length());
        }
        longestWordSize = longest;
    }

    /** Returns true if an acceptor is registered for {@code word}. */
    public boolean isValidProtocolWord(String word) {
        return protocols.containsKey(word);
    }

    /**
     * Routes {@code client} to the acceptor registered for {@code word}.
     * <p>
     * First clears the socket's read timeout ({@code setSoTimeout(0)}). If that
     * fails, or no acceptor knows {@code word}, the socket is closed and
     * nothing is dispatched.
     *
     * @param word      the protocol word read from the connection
     * @param client    the accepted socket
     * @param newThread whether a blocking acceptor may be run on a new thread
     */
    public void dispatch(final String word, final Socket client, boolean newThread) {
        try {
            client.setSoTimeout(0);
        } catch (SocketException se) {
            LOG.warn("Unable to set soTimeout, closing client", se);
            IOUtils.close(client);
            return;
        }

        Delegator handler = protocols.get(word);
        if (handler != null) {
            handler.delegate(word, client, newThread);
            return;
        }

        // Nobody registered this word: log it and discard the connection.
        if (LOG.isErrorEnabled()) {
            LOG.error("Unknown protocol: " + word);
        }
        IOUtils.close(client);
    }

    /**
     * Pairs an acceptor with its local-only policy and blocking flag. It
     * decides whether a connection is accepted at all, and on which thread
     * the acceptor runs.
     */
    private class Delegator {
        private final ConnectionAcceptor target;
        private final boolean localOnly;
        private final boolean blocking;

        Delegator(ConnectionAcceptor target, boolean localOnly, boolean blocking) {
            this.target = target;
            this.localOnly = localOnly;
            this.blocking = blocking;
        }

        /**
         * Hands {@code sock} to the wrapped acceptor, or closes it if the
         * local-only policy rejects it.
         * <p>
         * A blocking acceptor runs via {@code ThreadExecutor.startThread} when
         * {@code newThread} is true. In every other case it runs on the
         * calling thread.
         */
        public void delegate(final String word, final Socket sock, boolean newThread) {
            if (shouldDrop(sock)) {
                IOUtils.close(sock);
                return;
            }

            if (!blocking || !newThread) {
                if (LOG.isDebugEnabled())
                    LOG.debug("Handling dispatched word: " + word + " in same thread");
                target.acceptConnection(word, sock);
                return;
            }

            if (LOG.isDebugEnabled())
                LOG.debug("Spawning new thread to dispatch: " + word);
            ThreadExecutor.startThread(new Runnable() {
                public void run() {
                    target.acceptConnection(word, sock);
                }
            }, "IncomingConnection");
        }

        /**
         * Returns true if the connection must be rejected. That happens when:
         * <ul>
         * <li>a local-only acceptor sees a non-localhost peer, or</li>
         * <li>a non-local acceptor sees a localhost peer whose socket local
         *     address {@code networkInstanceUtils} considers private.</li>
         * </ul>
         */
        private boolean shouldDrop(Socket socket) {
            boolean fromLocalHost = NetworkUtils.isLocalHost(socket);
            if (localOnly) {
                if (!fromLocalHost) {
                    LOG.debug("Dropping because we want a local connection, and this isn't localhost");
                    return true;
                }
                return false;
            }
            if (fromLocalHost && networkInstanceUtils.isPrivateAddress(socket.getLocalAddress())) {
                LOG.debug("Dropping because we want an external connection, and this is localhost");
                return true;
            }
            return false;
        }
    }
}