package org.jdiameter.client.impl.transport.tcp;
import org.jdiameter.api.AvpDataException;
import org.jdiameter.client.api.io.NotInitializedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousCloseException;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
class TCPTransportClient implements Runnable {
private TCPClientConnection parentConnection;
public static final int DEFAULT_BUFFER_SIZE = 1024;
public static final int DEFAULT_STORAGE_SIZE = 2048;
protected boolean stop = false;
protected Thread selfThread;
protected int bufferSize = DEFAULT_BUFFER_SIZE;
protected ByteBuffer buffer = ByteBuffer.allocate(this.bufferSize);
protected InetSocketAddress destAddress;
protected InetSocketAddress origAddress;
protected SocketChannel socketChannel;
protected Lock lock = new ReentrantLock();
protected int storageSize = DEFAULT_STORAGE_SIZE;
protected ByteBuffer storage = ByteBuffer.allocate(storageSize);
protected Logger logger = LoggerFactory.getLogger(TCPTransportClient.class);
TCPTransportClient() {
}
/**
* Default constructor
*
* @param parenConnection connection created this transport
*/
TCPTransportClient(TCPClientConnection parenConnection) {
this.parentConnection = parenConnection;
}
/** Network init socket */
public void initialize() throws IOException, NotInitializedException {
if (destAddress == null) {
throw new NotInitializedException("Destination address is not set");
}
socketChannel = SelectorProvider.provider().openSocketChannel();
if (origAddress != null) {
socketChannel.socket().bind(origAddress);
}
socketChannel.connect(destAddress);
socketChannel.configureBlocking(true);
getParent().onConnected();
}
public TCPClientConnection getParent() {
return parentConnection;
}
public void initialize(Socket socket) throws IOException, NotInitializedException {
socketChannel = socket.getChannel();
socketChannel.configureBlocking(true);
destAddress = new InetSocketAddress(socket.getInetAddress(), socket.getPort());
}
public void start() throws Exception {
logger.debug("Starting transport");
if (socketChannel == null) {
throw new NotInitializedException("Transport is not initialized");
}
if (!socketChannel.isConnected()) {
throw new NotInitializedException("Socket channel is not connected");
}
if (getParent() == null) {
throw new NotInitializedException("No parent connection is set is set");
}
if (selfThread == null || !selfThread.isAlive()) {
selfThread = new Thread(this); // TODO
}
if (!selfThread.isAlive()) {
selfThread.start();
}
}
public void run() {
logger.debug("Transport is started");
try {
while (!stop) {
int dataLength = socketChannel.read(buffer);
if (dataLength == -1) {
break;
}
buffer.flip();
byte[] data = new byte[buffer.limit()];
buffer.get(data);
append(data);
buffer.clear();
}
}
catch (ClosedByInterruptException e) {
logger.debug("Transport exception ", e);
} catch (AsynchronousCloseException e) {
logger.debug("Transport exception ", e);
} catch (Throwable e) {
logger.debug("Transport exception ", e);
} finally {
try {
clearBuffer();
if (socketChannel != null && socketChannel.isOpen()) {
socketChannel.close();
}
getParent().onDisconnect();
}
catch (Exception e) {
logger.debug("Error", e);
}
stop = false;
logger.info("Read thread is stopped");
}
}
public void stop() throws Exception {
logger.debug("Stopping transport");
stop = true;
if (socketChannel != null && socketChannel.isOpen()) {
socketChannel.close();
}
if (selfThread != null) {
selfThread.join(100);
}
clearBuffer();
logger.debug("Transport is stopped");
}
public void release() throws Exception {
stop();
destAddress = null;
}
private void clearBuffer() throws IOException {
bufferSize = DEFAULT_BUFFER_SIZE;
buffer = ByteBuffer.allocate(bufferSize);
}
public InetSocketAddress getDestAddress() {
return destAddress;
}
public void setDestAddress(InetSocketAddress address) {
destAddress = address;
logger.debug("Destination address is set to {} : {}",destAddress.getHostName(), destAddress.getPort());
}
public void setOrigAddress(InetSocketAddress address) {
origAddress = address;
}
public void sendMessage(ByteBuffer bytes) throws IOException {
int rc;
lock.lock();
try {
rc = socketChannel.write(bytes);
}
catch (Exception e) {
logger.debug("Can not send message", e);
throw new IOException("Error while sending message: " + e);
}
finally {
lock.unlock();
}
if (rc == -1) {
throw new IOException("Connection closed");
}
}
public String toString() {
StringBuffer buffer = new StringBuffer();
buffer.append("Transport to ");
if (this.destAddress != null) {
buffer.append(this.destAddress.getHostName());
buffer.append(":");
buffer.append(this.destAddress.getPort());
} else {
buffer.append("null");
}
buffer.append("@");
buffer.append(super.toString());
return buffer.toString();
}
boolean isConnected() {
return socketChannel != null && socketChannel.isConnected();
}
/**
* Adds data to storage
*
* @param data data to add
*/
void append(byte[] data) {
if (storage.position() + data.length >= storage.capacity()) {
ByteBuffer tmp = ByteBuffer.allocate(storage.limit() + data.length * 2);
byte[] tmpData = new byte[storage.position()];
storage.flip();
storage.get(tmpData);
tmp.put(tmpData);
storage = tmp;
logger.warn("Increase storage size. Current size is {}", storage.array().length);
}
try {
storage.put(data);
}
catch (BufferOverflowException boe) {
logger.error("Buffer overflow occured", boe);
}
boolean messageReseived;
do {
messageReseived = seekMessage(storage);
} while (messageReseived);
}
private boolean seekMessage(ByteBuffer localStorage) {
if (storage.position() == 0) {
return false;
}
storage.flip();
int tmp = localStorage.getInt();
localStorage.position(0);
byte vers = (byte) (tmp >> 24);
if (vers != 1) {
return false;
}
int dataLength = (tmp & 0xFFFFFF);
if (localStorage.limit() < dataLength) {
localStorage.position(localStorage.limit());
localStorage.limit(localStorage.capacity());
return false;
}
byte[] data = new byte[dataLength];
localStorage.get(data);
localStorage.position(dataLength);
localStorage.compact();
try {
getParent().onMessageReveived(ByteBuffer.wrap(data));
}
catch (AvpDataException e) {
logger.debug("Garbage was received from server");
storage.clear();
getParent().onAvpDataException(e);
}
return true;
}
}