package org.infinispan.client.hotrod.impl.transport.tcp;
import static org.infinispan.commons.io.SignedNumeric.writeSignedInt;
import static org.infinispan.commons.io.UnsignedNumeric.readUnsignedInt;
import static org.infinispan.commons.io.UnsignedNumeric.readUnsignedLong;
import static org.infinispan.commons.io.UnsignedNumeric.writeUnsignedInt;
import static org.infinispan.commons.io.UnsignedNumeric.writeUnsignedLong;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.channels.SocketChannel;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.security.sasl.SaslClient;
import org.infinispan.client.hotrod.exceptions.TransportException;
import org.infinispan.client.hotrod.impl.transport.AbstractTransport;
import org.infinispan.client.hotrod.impl.transport.TransportFactory;
import org.infinispan.client.hotrod.logging.Log;
import org.infinispan.client.hotrod.logging.LogFactory;
import org.infinispan.commons.util.Util;
/**
* Transport implementation based on TCP.
*
* @author Mircea.Markus@jboss.com
* @since 4.1
*/
public class TcpTransport extends AbstractTransport {
public static final int SOCKET_STREAM_BUFFER = 8 * 1024;
//needed for debugging
private static AtomicLong ID_COUNTER = new AtomicLong(0);
private static final Log log = LogFactory.getLog(TcpTransport.class, Log.class);
private static final boolean trace = log.isTraceEnabled();
private final Socket socket;
private final SocketChannel socketChannel;
private InputStream socketInputStream;
private OutputStream socketOutputStream;
private final SocketAddress serverAddress;
private final long id = ID_COUNTER.incrementAndGet();
private volatile boolean invalid;
private SaslClient saslClient;
public TcpTransport(SocketAddress serverAddress, TransportFactory transportFactory) {
super(transportFactory);
this.serverAddress = serverAddress;
try {
if (transportFactory.getSSLContext() != null) {
socketChannel = null; // We don't use a SocketChannel in the SSL case
SSLContext sslContext = transportFactory.getSSLContext();
SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
socket = sslSocketFactory.createSocket();
setSniHostName(transportFactory.getSniHostName());
} else {
socketChannel = SocketChannel.open();
socket = socketChannel.socket();
}
socket.connect(serverAddress, transportFactory.getConnectTimeout());
socket.setTcpNoDelay(transportFactory.isTcpNoDelay());
socket.setKeepAlive(transportFactory.isTcpKeepAlive());
socket.setSoTimeout(transportFactory.getSoTimeout());
socketInputStream = new BufferedInputStream(socket.getInputStream(), SOCKET_STREAM_BUFFER);
// ensure we don't send a packet for every output byte
socketOutputStream = new BufferedOutputStream(socket.getOutputStream(), SOCKET_STREAM_BUFFER);
} catch (Exception e) {
String message = String.format("Could not connect to server: %s", serverAddress);
log.tracef(e, "Could not connect to server: %s", serverAddress);
throw new TransportException(message, e, serverAddress);
}
}
private void setSniHostName(String sniHostName) {
if(sniHostName != null) {
SSLSocket sslSocket = (SSLSocket) this.socket;
SSLParameters sslParameters = sslSocket.getSSLParameters();
sslParameters.setServerNames(Arrays.asList(new SNIHostName(sniHostName)));
sslSocket.setSSLParameters(sslParameters);
}
}
void setSaslClient(SaslClient saslClient) {
this.saslClient = saslClient;
try {
this.socketInputStream = new SaslInputStream(socket.getInputStream(), saslClient);
this.socketOutputStream = new SaslOutputStream(socket.getOutputStream(), saslClient);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
}
@Override
public void writeVInt(int vInt) {
try {
writeUnsignedInt(socketOutputStream, vInt);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
}
@Override
public void writeSignedVInt(int vInt) {
try {
writeSignedInt(socketOutputStream, vInt);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
}
@Override
public void writeVLong(long l) {
try {
writeUnsignedLong(socketOutputStream, l);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
}
@Override
public long readVLong() {
try {
return readUnsignedLong(socketInputStream);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
}
@Override
public int readVInt() {
try {
return readUnsignedInt(socketInputStream);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
}
@Override
protected void writeBytes(byte[] toAppend) {
try {
socketOutputStream.write(toAppend);
if (trace) {
log.tracef("Wrote %d bytes", toAppend.length);
}
} catch (IOException e) {
invalid = true;
throw new TransportException(
"Problems writing data to stream", e, serverAddress);
}
}
@Override
protected void writeBytes(byte[] toAppend, int offset, int count) {
try {
socketOutputStream.write(toAppend, offset, count);
if (trace) {
log.tracef("Wrote %d bytes", toAppend.length);
}
} catch (IOException e) {
invalid = true;
throw new TransportException(
"Problems writing data to stream", e, serverAddress);
}
}
@Override
public void writeByte(short toWrite) {
try {
socketOutputStream.write(toWrite);
if (trace) {
log.tracef("Wrote byte %d", toWrite);
}
} catch (IOException e) {
invalid = true;
throw new TransportException(
"Problems writing data to stream", e, serverAddress);
}
}
@Override
public void flush() {
try {
socketOutputStream.flush();
if (trace) {
log.tracef("Flushed socket: %s", socket);
}
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
}
@Override
public short readByte() {
int resultInt;
try {
resultInt = socketInputStream.read();
if (trace)
log.tracef("Read byte %d from socket input in %s", resultInt, socket);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
if (resultInt == -1) {
throw new TransportException("End of stream reached!", serverAddress);
}
return (short) resultInt;
}
@Override
public void release() {
destroy();
}
@Override
public void readByteArray(byte[] result, int size) {
boolean done = false;
int offset = 0;
do {
int read;
try {
int len = size - offset;
if (trace) {
log.tracef("Offset: %d, len=%d, size=%d", offset, len, size);
}
read = socketInputStream.read(result, offset, len);
} catch (IOException e) {
invalid = true;
throw new TransportException(e, serverAddress);
}
if (read == -1) {
throw new RuntimeException("End of stream reached!");
}
if (read + offset == size) {
done = true;
} else {
offset += read;
if (offset > result.length) {
throw new IllegalStateException("Assertion!");
}
}
} while (!done);
if (trace) {
log.tracef("Successfully read array with size: %d", size);
}
}
@Override
public byte[] readByteArray(final int size) {
byte[] result = new byte[size];
readByteArray(result, size);
return result;
}
public SocketAddress getServerAddress() {
return serverAddress;
}
@Override
public String toString() {
return "TcpTransport{" +
"socket=" + socket +
", serverAddress=" + serverAddress +
", id =" + id +
"} ";
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TcpTransport that = (TcpTransport) o;
if (serverAddress != null ? !serverAddress.equals(that.serverAddress) : that.serverAddress != null) {
return false;
}
if (socket != null ? !socket.equals(that.socket) : that.socket != null) {
return false;
}
return true;
}
@Override
public int hashCode() {
int result = socket != null ? socket.hashCode() : 0;
result = 31 * result + (serverAddress != null ? serverAddress.hashCode() : 0);
return result;
}
public void destroy() {
try {
if (socketInputStream != null) socketInputStream.close();
if (socketOutputStream != null) socketOutputStream.close();
if (socketChannel != null) socketChannel.close();
if (socket != null) socket.close();
if (saslClient != null) saslClient.dispose();
if (trace) {
log.tracef("Successfully closed socket: %s", socket);
}
} catch (IOException e) {
invalid = true;
log.errorClosingSocket(this, e);
// Just in case an exception is thrown, make sure they're fully closed
Util.close(socketInputStream, socketOutputStream, socketChannel);
Util.close(socket);
}
}
@Override
public boolean isValid() {
return !socket.isClosed() && !invalid;
}
public long getId() {
return id;
}
@Override
public byte[] dumpStream() {
ByteArrayOutputStream os = new ByteArrayOutputStream();
try {
socket.setSoTimeout(5000);
// Read 32kb at most
for (int i = 0; i < 32768; i++) {
int b = socketInputStream.read();
if (b < 0) {
break;
}
os.write(b);
}
} catch (IOException e) {
// Ignore
} finally {
Util.close(socket);
}
return os.toByteArray();
}
@Override
public SocketAddress getRemoteSocketAddress() {
return socket.getRemoteSocketAddress();
}
@Override
public void invalidate() {
invalid = true;
}
}