/*
* JBoss, Home of Professional Open Source.
* Copyright 2008, Red Hat Middleware LLC, and individual contributors
* as indicated by the @author tags. See the copyright.txt file in the
* distribution for a full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.jboss.invocation.pooled.interfaces;
import java.io.IOException;
import java.io.Externalizable;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.BufferedOutputStream;
import java.io.BufferedInputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.EOFException;
import java.io.OptionalDataException;
import java.io.UnsupportedEncodingException;
import java.io.InterruptedIOException;
import java.net.Socket;
import java.net.SocketException;
import java.rmi.MarshalledObject;
import java.rmi.NoSuchObjectException;
import java.rmi.ServerException;
import java.rmi.ConnectException;
import java.util.Iterator;
import java.util.Map;
import java.util.List;
import java.util.LinkedList;
import javax.transaction.TransactionRolledbackException;
import javax.transaction.SystemException;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.SSLException;
import org.jboss.invocation.Invocation;
import org.jboss.invocation.Invoker;
import org.jboss.tm.TransactionPropagationContextFactory;
import org.jboss.tm.TransactionPropagationContextUtil;
import org.jboss.logging.Logger;
import EDU.oswego.cs.dl.util.concurrent.ConcurrentReaderHashMap;
/**
 * Client-side invoker proxy that sends invocations to a pooled invoker
 * server. Client socket connections are pooled per {@link ServerAddress}
 * to avoid the overhead of opening a new connection for every request,
 * which is what RMI appears to do.
 *
 * <p>The public static timing and usage counters are simple measurements
 * updated without synchronization, so they are not thread safe.</p>
 *
 * @author <a href="mailto:bill@jboss.org">Bill Burke</a>
 * @author Scott.Stark@jboss.org
 * @version $Revision: 81030 $
 */
public class PooledInvokerProxy
   implements Invoker, Externalizable
{
   // Attributes ----------------------------------------------------
   private static final Logger log = Logger.getLogger(PooledInvokerProxy.class);
   /** The serialVersionUID @since 1.1.4.3 */
   private static final long serialVersionUID = -1456509931095566410L;
   /** The current wire format we write */
   private static final int WIRE_VERSION = 1;

   // Simple performance measurements, not thread safe
   public static long getSocketTime = 0;
   public static long readTime = 0;
   public static long writeTime = 0;
   public static long serializeTime = 0;
   public static long deserializeTime = 0;
   /** The number of times a connection has been obtained from a pool */
   public static long usedPooled = 0;
   /** The number of connections in use */
   private static int inUseCount = 0;
   /** The number of socket connections made */
   private static long socketConnectCount = 0;
   /** The number of socket close calls made */
   private static long socketCloseCount = 0;

   /**
    * The default number of connection attempts made by getConnection(),
    * used when no retry count is supplied or deserialized.
    */
   public static int MAX_RETRIES = 10;

   /** A class wide pool Map<ServerAddress, LinkedList<ClientSocket>> */
   protected static final Map connectionPools = new ConcurrentReaderHashMap();

   /**
    * connection information
    */
   protected ServerAddress address;

   /**
    * Pool for this invoker. This is shared between all
    * instances of proxies attached to a specific invoker.
    * This should not be serializable, but is for backward compatibility.
    */
   protected LinkedList pool = null;

   /** The maximum number of idle connections kept in the pool */
   protected int maxPoolSize;

   /**
    * The maximum number of connection attempts made by getConnection().
    * Attempts that fail with an InterruptedIOException or SocketException
    * are retried until this many attempts have been made.
    */
   protected int retryCount = 1;

   /** The logging trace flag */
   private transient boolean trace;

   /**
    * An encapsulation of a client connection: the socket plus the object
    * streams layered over it. For an SSLSocket the constructor starts the
    * handshake and then waits for the handshake-completed notification,
    * which also supplies the sessionID.
    */
   protected static class ClientSocket
      implements HandshakeCompletedListener
   {
      public ObjectOutputStream out;
      public ObjectInputStream in;
      public Socket socket;
      public int timeout;
      /** The SSL session id; stays null for a non-SSL socket */
      public String sessionID;
      /**
       * Set by handshakeCompleted(). The JSSE may deliver that callback on a
       * thread other than the one polling this flag in the constructor, so
       * the flag is volatile.
       */
      private volatile boolean handshakeComplete = false;
      private boolean trace;

      /**
       * Wrap a connected socket, creating its object streams and, for an
       * SSLSocket, completing the SSL handshake.
       *
       * @param socket the connected socket
       * @param timeout the SO_TIMEOUT to use for normal reads
       * @throws Exception if the streams cannot be created or the handshake
       *    fails; the socket is closed before the exception propagates
       */
      public ClientSocket(Socket socket, int timeout) throws Exception
      {
         this.socket = socket;
         trace = log.isTraceEnabled();
         try
         {
            initConnection(timeout);
         }
         catch (Exception e)
         {
            // Close now rather than leaving the connection open until finalize()
            closeSocket(socket);
            this.socket = null;
            throw e;
         }
      }

      /** Perform the handshake (if SSL) and create the object streams. */
      private void initConnection(int timeout) throws Exception
      {
         boolean needHandshake = false;
         if( socket instanceof SSLSocket )
         {
            SSLSocket ssl = (SSLSocket) socket;
            ssl.addHandshakeCompletedListener(this);
            if( trace )
               log.trace("Starting SSL handshake");
            needHandshake = true;
            handshakeComplete = false;
            ssl.startHandshake();
         }
         socket.setSoTimeout(timeout);
         this.timeout = timeout;
         // Flush the stream header right away so the peer's ObjectInputStream can read it
         out = new OptimizedObjectOutputStream(new BufferedOutputStream(socket.getOutputStream()));
         out.flush();
         in = new OptimizedObjectInputStream(new BufferedInputStream(socket.getInputStream()));
         if( needHandshake )
            awaitHandshake(timeout);
      }

      /**
       * Poll the input stream with one-second read timeouts until
       * handshakeCompleted() sets the flag, giving up after 60 attempts.
       */
      private void awaitHandshake(int timeout) throws IOException
      {
         socket.setSoTimeout(1000);
         for(int n = 0; handshakeComplete == false && n < 60; n ++)
         {
            try
            {
               in.read();
            }
            catch(SSLException e)
            {
               if( trace )
                  log.trace("Error while waiting for handshake to complete", e);
               throw e;
            }
            catch(IOException e)
            {
               // Expected read timeouts; loop around and re-check the flag
               if( trace )
                  log.trace("Handshaked read()", e);
            }
         }
         if( handshakeComplete == false )
            throw new SSLException("Handshaked failed to complete in 60 seconds");
         // Restore the original timeout
         socket.setSoTimeout(timeout);
      }

      /**
       * Record the SSL session id of the completed handshake and release
       * the thread waiting in the constructor.
       */
      public void handshakeCompleted(HandshakeCompletedEvent event)
      {
         byte[] id = event.getSession().getId();
         String sid;
         try
         {
            sid = new String(id, "UTF-8");
         }
         catch (UnsupportedEncodingException e)
         {
            log.warn("Failed to create session id using UTF-8, using default", e);
            sid = new String(id);
         }
         sessionID = sid;
         // Written after sessionID so the volatile write publishes it to the waiting thread
         handshakeComplete = true;
         if( trace )
         {
            log.trace("handshakeCompleted, event="+event+", sessionID="+sessionID);
         }
      }

      public String toString()
      {
         StringBuffer tmp = new StringBuffer("ClientSocket@");
         tmp.append(System.identityHashCode(this));
         tmp.append('[');
         tmp.append("socket=");
         // append(Object) tolerates a socket that has already been closed and nulled
         tmp.append(socket);
         tmp.append(']');
         return tmp.toString();
      }

      /**
       * Close a socket that was never explicitly closed.
       * @todo should this be handled with weak references as this should
       * work better with gc
       */
      protected void finalize()
      {
         if (socket != null)
         {
            if( trace )
               log.trace("Closing socket in finalize: "+socket);
            try
            {
               closeSocket(socket);
            }
            finally
            {
               socket = null;
            }
         }
      }
   }

   /**
    * Close a raw socket, counting the close call. Failures are ignored
    * because every caller is discarding the connection anyway.
    */
   private static void closeSocket(Socket s)
   {
      socketCloseCount ++;
      try
      {
         s.close();
      }
      catch (Exception ignored)
      {
         // best effort: the connection is being thrown away
      }
   }

   /**
    * Close the socket held by a ClientSocket (if any) and drop the
    * reference so finalize() does not close it again.
    */
   private static void closeClientSocket(ClientSocket cs)
   {
      if( cs != null && cs.socket != null )
      {
         try
         {
            closeSocket(cs.socket);
         }
         finally
         {
            cs.socket = null;
         }
      }
   }

   /**
    * Clear all class level timing stats and the pooled-use counter.
    * The in-use, connect and close counters are not reset.
    */
   public static void clearStats()
   {
      getSocketTime = 0;
      readTime = 0;
      writeTime = 0;
      serializeTime = 0;
      deserializeTime = 0;
      usedPooled = 0;
   }

   /**
    * @return the active number of client connections
    */
   public static long getInUseCount()
   {
      return inUseCount;
   }

   /**
    * @return the number of times a connection was returned from a pool
    */
   public static long getUsedPooled()
   {
      return usedPooled;
   }

   /**
    * @return the number of new socket connections made
    */
   public static long getSocketConnectCount()
   {
      return socketConnectCount;
   }

   /**
    * @return the number of socket close calls made
    */
   public static long getSocketCloseCount()
   {
      return socketCloseCount;
   }

   /**
    * @return the total number of pooled connections across all ServerAddresses
    */
   public static int getTotalPoolCount()
   {
      int count = 0;
      Iterator iter = connectionPools.values().iterator();
      while( iter.hasNext() )
      {
         List pool = (List) iter.next();
         if( pool != null )
            count += pool.size();
      }
      return count;
   }

   /**
    * @return the proxy local pool count. The pool is only assigned by
    *    initPool() (called from readExternal), so this throws a
    *    NullPointerException before then.
    */
   public long getPoolCount()
   {
      return pool.size();
   }

   /**
    * Exposed for externalization.
    */
   public PooledInvokerProxy()
   {
      super();
      trace = log.isTraceEnabled();
   }

   /**
    * Create a new Proxy that makes up to MAX_RETRIES connection attempts.
    *
    * @param sa the server address to connect to
    * @param maxPoolSize the maximum number of idle connections to keep
    */
   public PooledInvokerProxy(ServerAddress sa, int maxPoolSize)
   {
      this(sa, maxPoolSize, MAX_RETRIES);
   }

   /**
    * Create a new Proxy.
    *
    * @param sa the server address to connect to
    * @param maxPoolSize the maximum number of idle connections to keep
    * @param retryCount the maximum number of connection attempts
    */
   public PooledInvokerProxy(ServerAddress sa, int maxPoolSize, int retryCount)
   {
      this.address = sa;
      this.maxPoolSize = maxPoolSize;
      this.retryCount = retryCount;
      trace = log.isTraceEnabled();
   }

   /**
    * Close all sockets in a specific pool. This is best effort: failures
    * are ignored.
    *
    * @param sa the address whose pool should be emptied
    */
   public static void clearPool(ServerAddress sa)
   {
      boolean trace = log.isTraceEnabled();
      if( trace )
         log.trace("clearPool, sa: "+sa);
      try
      {
         LinkedList thepool = (LinkedList)connectionPools.get(sa);
         if (thepool == null) return;
         synchronized (thepool)
         {
            while( thepool.isEmpty() == false )
            {
               ClientSocket cs = (ClientSocket)thepool.removeFirst();
               if( trace )
                  log.trace("Closing, ClientSocket: "+cs);
               closeClientSocket(cs);
            }
         }
      }
      catch (Exception ex)
      {
         // ignored: clearing the pool is best effort
      }
   }

   /**
    * Close all sockets in all pools
    */
   public static void clearPools()
   {
      synchronized (connectionPools)
      {
         Iterator it = connectionPools.keySet().iterator();
         while (it.hasNext())
         {
            ServerAddress sa = (ServerAddress)it.next();
            clearPool(sa);
         }
      }
   }

   /** Two proxies are equal when they target the same ServerAddress. */
   public boolean equals(Object other)
   {
      if(! (other instanceof PooledInvokerProxy))
         return false;
      return (address.equals( ((PooledInvokerProxy)other).address ));
   }

   public int hashCode()
   {
      return address.hashCode();
   }

   /**
    * Bind this proxy to the class wide pool for its address, creating the
    * pool if this is the first proxy for that address.
    */
   protected void initPool()
   {
      synchronized (connectionPools)
      {
         pool = (LinkedList)connectionPools.get(address);
         if (pool == null)
         {
            pool = new LinkedList();
            connectionPools.put(address, pool);
         }
      }
   }

   /**
    * Obtain a connection, preferring a validated pooled one and otherwise
    * opening a new socket.
    *
    * @return a connection that is counted as in use
    * @throws Exception the last connect failure, or a ConnectException if
    *    retryCount allows no attempts at all
    */
   protected ClientSocket getConnection() throws Exception
   {
      //
      // Need to retry a few times
      // on socket connection because, at least on Windoze,
      // if too many concurrent threads try to connect
      // at same time, you get ConnectionRefused
      //
      // Retrying seems to be the most performant.
      //
      // This problem always happens with RMI and seems to
      // have nothing to do with backlog or number of threads
      // waiting in accept() on the server.
      //
      for (int i = 0; i < retryCount; i++)
      {
         ClientSocket pooled = getPooledConnection();
         if (pooled != null)
         {
            usedPooled++;
            inUseCount ++;
            return pooled;
         }

         try
         {
            Socket socket = openSocket();
            // ClientSocket closes the socket itself if its setup fails
            ClientSocket cs = new ClientSocket(socket, address.timeout);
            inUseCount ++;
            if( trace )
            {
               log.trace("New ClientSocket: "+cs
                  +", usedPooled="+ usedPooled
                  +", inUseCount="+ inUseCount
                  +", socketConnectCount="+ socketConnectCount
                  +", socketCloseCount="+ socketCloseCount
               );
            }
            return cs;
         }
         catch (Exception ex)
         {
            if( ex instanceof InterruptedIOException || ex instanceof SocketException )
            {
               if( trace )
                  log.trace("Connect failed", ex);
               if (i + 1 < retryCount)
               {
                  Thread.sleep(1);
                  continue;
               }
            }
            throw ex;
         }
      }
      // Only reached when retryCount allows no attempts
      throw new ConnectException("Failed to obtain a socket, tries="+retryCount);
   }

   /**
    * Open a new socket to the server address and apply the TCP_NODELAY
    * setting, closing the socket if that setting cannot be applied.
    */
   private Socket openSocket() throws Exception
   {
      if( trace)
      {
         log.trace("Connecting to addr: "+address.address
            +", port: "+address.port
            +",clientSocketFactory: "+address.clientSocketFactory
            +",enableTcpNoDelay: "+address.enableTcpNoDelay
            +",timeout: "+address.timeout);
      }
      Socket socket;
      if( address.clientSocketFactory != null )
         socket = address.clientSocketFactory.createSocket(address.address, address.port);
      else
         socket = new Socket(address.address, address.port);
      socketConnectCount ++;
      if( trace )
         log.trace("Connected, socket="+socket);
      try
      {
         socket.setTcpNoDelay(address.enableTcpNoDelay);
      }
      catch (Exception e)
      {
         closeSocket(socket);
         throw e;
      }
      return socket;
   }

   /**
    * Remove and return the first pooled connection.
    *
    * @return the first pooled connection, or null if the pool is empty
    */
   protected ClientSocket firstConnection()
   {
      synchronized (pool)
      {
         if(pool.size() > 0)
            return (ClientSocket)pool.removeFirst();
      }
      return null;
   }

   /**
    * Take connections from the pool until one passes an ACK round trip.
    * Connections that fail the check are closed and discarded.
    *
    * @return a live pooled connection, or null if none is available
    */
   protected ClientSocket getPooledConnection()
   {
      ClientSocket socket = null;
      while ((socket = firstConnection()) != null)
      {
         try
         {
            // Test to see if socket is alive by send ACK message
            if( trace )
               log.trace("Checking pooled socket: "+socket+", address: "+socket.socket.getLocalSocketAddress());
            final byte ACK = 1;
            socket.out.writeByte(ACK);
            socket.out.flush();
            socket.in.readByte();
            if( trace )
            {
               log.trace("Using pooled ClientSocket: "+socket
                  +", usedPooled="+ usedPooled
                  +", inUseCount="+ inUseCount
                  +", socketConnectCount="+ socketConnectCount
                  +", socketCloseCount="+ socketCloseCount
               );
            }
            return socket;
         }
         catch (Exception ex)
         {
            if( trace )
               log.trace("Failed to validate pooled socket: "+socket, ex);
            closeClientSocket(socket);
         }
      }
      return null;
   }

   /**
    * Return a socket to the pool
    * @param socket the in-use connection to return
    * @return true if socket was added to the pool (and is no longer counted
    *    as in use), false if the pool was full
    */
   protected boolean returnConnection(ClientSocket socket)
   {
      boolean pooled = false;
      synchronized( pool )
      {
         if (pool.size() < maxPoolSize)
         {
            pool.add(socket);
            inUseCount --;
            pooled = true;
         }
      }
      return pooled;
   }

   /**
    * The name of the server.
    */
   public String getServerHostName() throws Exception
   {
      return address.address;
   }

   /**
    * Look up the transaction propagation context for the current thread.
    *
    * @todo MOVE TO TRANSACTION
    *
    * @return the transaction propagation context of the transaction
    *         associated with the current thread.
    *         Returns <code>null</code> if the transaction manager was never
    *         set, or if no transaction is associated with the current thread.
    */
   public Object getTransactionPropagationContext()
      throws SystemException
   {
      TransactionPropagationContextFactory tpcFactory = TransactionPropagationContextUtil.getTPCFactoryClientSide();
      return (tpcFactory == null) ? null : tpcFactory.getTransactionPropagationContext();
   }

   /**
    * Marshal the invocation, send it to the server over a pooled connection
    * and return the server's response.
    *
    * @param invocation the invocation to send
    * @return the response, unwrapped from a MarshalledObject if necessary
    * @throws java.rmi.ConnectException if writing the request or reading the
    *    response fails; the connection is then closed rather than pooled
    * @throws Exception any exception the server returned as the response;
    *    a ServerException wrapping a NoSuchObjectException or
    *    TransactionRolledbackException is unwrapped first
    */
   public Object invoke(Invocation invocation)
      throws Exception
   {
      boolean trace = log.isTraceEnabled();
      // We are going to go through a Remote invocation, switch to a Marshalled Invocation
      PooledMarshalledInvocation mi = new PooledMarshalledInvocation(invocation);

      // Set the transaction propagation context
      //  @todo: MOVE TO TRANSACTION
      mi.setTransactionPropagationContext(getTransactionPropagationContext());

      Object response = null;
      long start = System.currentTimeMillis();
      ClientSocket socket = getConnection();
      long end = System.currentTimeMillis() - start;
      getSocketTime += end;
      // Add the socket session if it exists
      if( socket.sessionID != null )
      {
         mi.setValue("SESSION_ID", socket.sessionID);
         if( trace )
            log.trace("Added SESSION_ID to invocation");
      }

      try
      {
         if( trace )
            log.trace("Sending invocation to: "+mi.getObjectName());
         // Time only the write; connection acquisition is counted in getSocketTime
         start = System.currentTimeMillis();
         socket.out.writeObject(mi);
         socket.out.reset();
         socket.out.writeObject(Boolean.TRUE); // for stupid ObjectInputStream reset
         socket.out.flush();
         socket.out.reset();
         end = System.currentTimeMillis() - start;
         writeTime += end;
         start = System.currentTimeMillis();
         response = socket.in.readObject();
         // to make sure stream gets reset
         // Stupid ObjectInputStream holds object graph
         // can only be set by the client/server sending a TC_RESET
         socket.in.readObject();
         end = System.currentTimeMillis() - start;
         readTime += end;
      }
      catch (Exception ex)
      {
         if( trace )
            log.trace("Failure during invoke", ex);
         // The stream state is unknown, so discard the connection instead of pooling it
         inUseCount --;
         closeClientSocket(socket);
         throw new java.rmi.ConnectException("Failure during invoke", ex);
      }

      // Put socket back in pool for reuse
      if( returnConnection(socket) == false )
      {
         // Pool is full, close the socket; it is no longer in use either way
         if( trace )
            log.trace("Closing unpooled socket: "+socket);
         inUseCount --;
         closeClientSocket(socket);
      }

      // Return response
      try
      {
         if (response instanceof Exception)
         {
            throw ((Exception)response);
         }
         if (response instanceof MarshalledObject)
         {
            return ((MarshalledObject)response).get();
         }
         return response;
      }
      catch (ServerException ex)
      {
         // Suns RMI implementation wraps NoSuchObjectException in
         // a ServerException. We cannot have that if we want
         // to comply with the spec, so we unwrap here.
         if (ex.detail instanceof NoSuchObjectException)
         {
            throw (NoSuchObjectException) ex.detail;
         }
         //likewise
         if (ex.detail instanceof TransactionRolledbackException)
         {
            throw (TransactionRolledbackException) ex.detail;
         }
         throw ex;
      }
   }

   /**
    * Write out the serializable data
    * @serialData address ServerAddress
    * @serialData maxPoolSize int
    * @serialData WIRE_VERSION int version
    * @serialData retryCount int
    * @param out the stream to write to
    * @throws IOException if writing fails
    */
   public void writeExternal(final ObjectOutput out)
      throws IOException
   {
      // The legacy wire format is address, maxPoolSize
      out.writeObject(address);
      out.writeInt(maxPoolSize);
      // Write out the current version format and its data
      out.writeInt(WIRE_VERSION);
      out.writeInt(retryCount);
   }

   /**
    * Read the data written by writeExternal, accepting the legacy format
    * that has no version, and bind this proxy to its connection pool.
    *
    * @param in the stream to read from
    * @throws IOException if reading fails
    * @throws ClassNotFoundException if the ServerAddress class is missing
    */
   public void readExternal(final ObjectInput in)
      throws IOException, ClassNotFoundException
   {
      trace = log.isTraceEnabled();
      address = (ServerAddress)in.readObject();
      maxPoolSize = in.readInt();
      int version = 0;
      try
      {
         version = in.readInt();
      }
      catch(EOFException e)
      {
         // No version written and there is no more data
      }
      catch(OptionalDataException e)
      {
         // No version written and there is data from other objects
      }
      switch( version )
      {
         case 0:
            // This has no retryCount, default it to the hard-coded value
            retryCount = MAX_RETRIES;
            break;
         case 1:
            readVersion1(in);
            break;
         default:
            /* Assume a newer version that only appends defaultable values
               after the version 1 data, so the version 1 fields still come
               first. The alternative would be to throw an exception.
            */
            readVersion1(in);
            break;
      }
      initPool();
   }

   /** Read the fields added by wire version 1. */
   private void readVersion1(final ObjectInput in)
      throws IOException
   {
      retryCount = in.readInt();
   }
}