package org.threadly.litesockets;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SocketChannel;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSession;
import org.threadly.concurrent.future.ListenableFuture;
import org.threadly.concurrent.future.SettableListenableFuture;
import org.threadly.litesockets.buffers.ReuseableMergedByteBuffers;
import org.threadly.litesockets.utils.SSLProcessor;
import org.threadly.util.ArgumentVerifier;
import org.threadly.util.Clock;
import org.threadly.util.Pair;
/**
* A Simple TCP client.
*
*/
@SuppressWarnings("deprecation")
public class TCPClient extends Client {
protected static final int DEFAULT_SOCKET_TIMEOUT = 10000;
protected static final int MIN_WRITE_BUFFER_SIZE = 8192;
protected static final int MAX_COMBINED_WRITE_BUFFER_SIZE = 65536;
private final ReuseableMergedByteBuffers writeBuffers = new ReuseableMergedByteBuffers();
private final Deque<Pair<Long, SettableListenableFuture<Long>>> writeFutures = new ArrayDeque<Pair<Long, SettableListenableFuture<Long>>>();
private final TCPSocketOptions tso = new TCPSocketOptions();
protected final AtomicBoolean startedConnection = new AtomicBoolean(false);
protected final SettableListenableFuture<Boolean> connectionFuture = new SettableListenableFuture<Boolean>(false);
protected final SocketChannel channel;
protected final InetSocketAddress remoteAddress;
private volatile ByteBuffer currentWriteBuffer = ByteBuffer.allocate(0);
private volatile SSLProcessor sslProcessor;
protected volatile int maxConnectionTime = DEFAULT_SOCKET_TIMEOUT;
protected volatile long connectExpiresAt = -1;
/**
* This creates TCPClient with a connection to the specified port and IP. This connection is not is not
* yet made {@link #connect()} must be called which will do the actual connect.
*
* @param sei The {@link SocketExecuter} implementation this client will use.
* @param host The hostname or IP address to connect this client too.
* @param port The port to connect this client too.
* @throws IOException - This is thrown if there are any problems making the socket.
*/
protected TCPClient(final SocketExecuter sei, final String host, final int port) throws IOException {
super(sei);
remoteAddress = new InetSocketAddress(host, port);
channel = SocketChannel.open();
channel.configureBlocking(false);
}
/**
* <p>This creates a TCPClient based off an already existing {@link SocketChannel}.
* This {@link SocketChannel} must already be connected.</p>
*
* @param sei the {@link SocketExecuter} to use for this client.
* @param channel the {@link SocketChannel} to use for this client.
* @throws IOException if there is anything wrong with the {@link SocketChannel} this will be thrown.
*/
protected TCPClient(final SocketExecuter sei, final SocketChannel channel) throws IOException {
super(sei);
if(! channel.isOpen()) {
throw new ClosedChannelException();
}
connectionFuture.setResult(true);
if(channel.isBlocking()) {
channel.configureBlocking(false);
}
this.channel = channel;
remoteAddress = (InetSocketAddress) channel.socket().getRemoteSocketAddress();
startedConnection.set(true);
}
@Override
public void setConnectionTimeout(final int timeout) {
ArgumentVerifier.assertGreaterThanZero(timeout, "Timeout");
this.maxConnectionTime = timeout;
}
@Override
public ListenableFuture<Boolean> connect(){
if(startedConnection.compareAndSet(false, true)) {
try {
channel.connect(remoteAddress);
connectExpiresAt = maxConnectionTime + Clock.accurateForwardProgressingMillis();
se.setClientOperations(this);
se.watchFuture(connectionFuture, maxConnectionTime);
} catch (Exception e) {
connectionFuture.setFailure(e);
close();
}
}
return connectionFuture;
}
@Override
protected void setConnectionStatus(final Throwable t) {
if(t == null) {
connectionFuture.setResult(true);
} else {
if(connectionFuture.setFailure(t)) {
close();
}
}
}
@Override
public boolean hasConnectionTimedOut() {
if(! startedConnection.get() || channel.isConnected()) {
return false;
}
return Clock.accurateForwardProgressingMillis() > connectExpiresAt;
}
@Override
public int getTimeout() {
return maxConnectionTime;
}
@Override
protected SocketChannel getChannel() {
return channel;
}
@Override
protected Socket getSocket() {
return channel.socket();
}
@Override
public void close() {
if(setClose()) {
se.setClientOperations(this);
final ClosedChannelException cce = new ClosedChannelException();
this.getClientsThreadExecutor().execute(new Runnable() {
@Override
public void run() {
synchronized(writerLock) {
for(final Pair<Long, SettableListenableFuture<Long>> p: writeFutures) {
p.getRight().setFailure(cce);
}
writeFutures.clear();
writeBuffers.discard(writeBuffers.remaining());
}
}});
this.callClosers();
}
}
@Override
public WireProtocol getProtocol() {
return WireProtocol.TCP;
}
@Override
public boolean canWrite() {
return writeBuffers.remaining() > 0 ;
}
@Override
public int getWriteBufferSize() {
return this.writeBuffers.remaining();
}
@Override
public int getMaxBufferSize() {
return this.maxBufferSize;
}
@Override
public void setMaxBufferSize(final int size) {
ArgumentVerifier.assertNotNegative(size, "size");
maxBufferSize = size;
if(channel.isConnected()) {
this.se.setClientOperations(this);
}
}
@Override
public ReuseableMergedByteBuffers getRead() {
ReuseableMergedByteBuffers mbb = super.getRead();
if(sslProcessor != null && sslProcessor.handShakeStarted() && mbb.remaining() > 0) {
mbb = sslProcessor.decrypt(mbb);
}
return mbb;
}
@Override
public ListenableFuture<?> write(final ByteBuffer bb) {
if(isClosed()) {
throw new IllegalStateException("Cannot write to closed client!");
}
synchronized(writerLock) {
final boolean needNotify = ! canWrite();
final SettableListenableFuture<Long> slf = new SettableListenableFuture<Long>(false);
if(sslProcessor != null && sslProcessor.handShakeStarted()) {
writeBuffers.add(sslProcessor.encrypt(bb));
} else {
writeBuffers.add(bb);
}
writeFutures.add(new Pair<Long, SettableListenableFuture<Long>>(writeBuffers.getTotalConsumedBytes()+writeBuffers.remaining(), slf));
if(needNotify && se != null && channel.isConnected()) {
se.setClientOperations(this);
}
return slf;
}
}
@Override
protected ByteBuffer getWriteBuffer() {
if(currentWriteBuffer.remaining() != 0) {
return currentWriteBuffer;
}
synchronized(writerLock) {
//This is to keep from doing a ton of little writes if we can. We will try to
//do at least 8k at a time, and up to 65k if we are already having to combine buffers
if(writeBuffers.nextPopSize() < MIN_WRITE_BUFFER_SIZE && writeBuffers.remaining() > writeBuffers.nextPopSize()) {
if(writeBuffers.remaining() < MAX_COMBINED_WRITE_BUFFER_SIZE) {
currentWriteBuffer = writeBuffers.pull(writeBuffers.remaining());
} else {
currentWriteBuffer = writeBuffers.pull(MAX_COMBINED_WRITE_BUFFER_SIZE);
}
} else {
currentWriteBuffer = writeBuffers.pop();
}
}
return currentWriteBuffer;
}
@Override
protected void reduceWrite(final int size) {
synchronized(writerLock) {
addWriteStats(size);
if(currentWriteBuffer.remaining() == 0) {
while(this.writeFutures.peekFirst() != null && writeFutures.peekFirst().getLeft() <= writeBuffers.getTotalConsumedBytes()) {
final Pair<Long, SettableListenableFuture<Long>> p = writeFutures.pollFirst();
p.getRight().setResult(p.getLeft());
}
}
}
}
@Override
public InetSocketAddress getRemoteSocketAddress() {
return remoteAddress;
}
@Override
public InetSocketAddress getLocalSocketAddress() {
return (InetSocketAddress) channel.socket().getLocalSocketAddress();
}
@Override
public String toString() {
return "TCPClient:FROM:"+getLocalSocketAddress()+":TO:"+getRemoteSocketAddress();
}
@Override
public boolean setSocketOption(final SocketOption so, final int value) {
try{
switch(so) {
case TCP_NODELAY: {
return tso.setTcpNoDelay(value > 0);
}
case SEND_BUFFER_SIZE: {
return tso.setSocketSendBuffer(value);
}
case RECV_BUFFER_SIZE: {
return tso.setSocketRecvBuffer(value);
}
case USE_NATIVE_BUFFERS: {
return tso.setNativeBuffers(value > 0);
}
default:
return false;
}
} catch(Exception e) {
}
return false;
}
@Override
public ClientOptions clientOptions() {
return tso;
}
public void setSSLEngine(final SSLEngine ssle) {
sslProcessor = new SSLProcessor(this, ssle);
}
public boolean isEncrypted() {
if(sslProcessor == null) {
return false;
}
return sslProcessor.isEncrypted();
}
public ListenableFuture<SSLSession> startSSL() {
if(sslProcessor != null) {
return sslProcessor.doHandShake();
}
throw new IllegalStateException("Must Set the SSLEngine before starting Encryption!");
}
/**
*
* @author lwahlmeier
*
*/
private class TCPSocketOptions extends BaseClientOptions {
@Override
public boolean setTcpNoDelay(boolean enabled) {
try {
channel.socket().setTcpNoDelay(enabled);
return true;
} catch (SocketException e) {
return false;
}
}
@Override
public boolean getTcpNoDelay() {
try {
return channel.socket().getTcpNoDelay();
} catch (SocketException e) {
return false;
}
}
@Override
public boolean setSocketSendBuffer(int size) {
try {
ArgumentVerifier.assertGreaterThanZero(size, "size");
int prev = channel.socket().getSendBufferSize();
channel.socket().setSendBufferSize(size);
if(channel.socket().getSendBufferSize() != size) {
channel.socket().setSendBufferSize(prev);
return false;
}
return true;
} catch (Exception e) {
return false;
}
}
@Override
public int getSocketSendBuffer() {
try {
return channel.socket().getSendBufferSize();
} catch (SocketException e) {
return -1;
}
}
@Override
public boolean setSocketRecvBuffer(int size) {
try {
ArgumentVerifier.assertGreaterThanZero(size, "size");
int prev = channel.socket().getReceiveBufferSize();
channel.socket().setReceiveBufferSize(size);
if(channel.socket().getReceiveBufferSize() != size) {
channel.socket().setReceiveBufferSize(prev);
return false;
}
return true;
} catch (Exception e) {
return false;
}
}
@Override
public int getSocketRecvBuffer() {
try {
return channel.socket().getReceiveBufferSize();
} catch (SocketException e) {
return -1;
}
}
}
}