package org.jboss.narayana.blacktie.jatmibroker.core.transport.hybrid.stomp;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.LogManager;
import org.apache.log4j.Logger;
import org.jboss.narayana.blacktie.jatmibroker.xatmi.Connection;
import org.jboss.narayana.blacktie.jatmibroker.xatmi.ConnectionException;
/**
 * Static helpers for a minimal STOMP client conversation over a plain {@link Socket}:
 * connecting, sending frames, receiving frames and requesting a disconnect.
 *
 * <p>
 * Frames are written as a command line, {@code key:value} header lines, a blank line, an optional
 * body and a NUL based terminator. Reads on a connection are serialised on the socket's monitor and
 * writes on the output stream's monitor.
 *
 * <p>
 * This class could be extended to support connection reconnection.
 */
public class StompManagement {
    private static final Logger log = LogManager.getLogger(StompManagement.class);

    /**
     * Sockets whose disconnect receipt has already been consumed by {@link #receive}. The list is
     * shared by every socket while callers only hold their own socket's monitor, so each access is
     * additionally guarded by the list's own monitor.
     */
    private static final List<Socket> disconnectedConnections = new ArrayList<Socket>();

    private static final byte[] COLON = {':'};
    private static final byte[] EOL = {'\n'};
    /** Frame terminator written by {@link #send}: NUL followed by two newlines. */
    private static final byte[] EOM = {'\0', '\n', '\n'};

    /**
     * Sends a DISCONNECT frame that requests a receipt and, unless the receipt was already consumed
     * by another call to {@link #receive}, reads one frame from the connection as the response.
     *
     * <p>
     * This method does not close the socket or either stream; that is left to the caller.
     *
     * @param socket the connection being shut down, also used as the read lock
     * @param outputStream the stream to write the DISCONNECT frame to
     * @param inputStream the stream to read the response from
     * @throws IOException if the frame cannot be written or the response cannot be read
     */
    public static void close(Socket socket, OutputStream outputStream, InputStream inputStream) throws IOException {
        log.debug("close");
        Message message = new Message();
        message.setCommand("DISCONNECT");
        Map<String, String> headers = new HashMap<String, String>();
        headers.put("receipt", "disconnect");
        message.setHeaders(headers);
        send(message, outputStream);
        log.debug("Sent disconnect");
        synchronized (socket) {
            if (!forgetDisconnected(socket)) {
                Message received = receive(socket, inputStream);
                if (received != null && "ERROR".equals(received.getCommand())) {
                    log.error("Did not receive the receipt for the disconnect:" + new String(received.getBody()));
                }
            }
            // receive() records the socket when it consumes the receipt itself, so clear that entry
            // as well to stop the list growing.
            forgetDisconnected(socket);
        }
    }

    /**
     * Opens a socket to the broker and performs the CONNECT handshake.
     *
     * @param host the broker host
     * @param port the broker port
     * @param username sent as the {@code login} header
     * @param password sent as the {@code passcode} header
     * @return the connected socket; the caller owns it from here on
     * @throws IOException if the socket cannot be opened or the handshake cannot be read or written
     * @throws ConnectionException if the broker answers with an ERROR frame or with no usable frame
     */
    public static Socket connect(String host, int port, String username, String password) throws IOException,
            ConnectionException {
        Socket socket = new Socket(host, port);
        boolean connected = false;
        try {
            InputStream inputStream = socket.getInputStream();
            OutputStream outputStream = socket.getOutputStream();
            Map<String, String> headers = new HashMap<String, String>();
            headers.put("login", username);
            headers.put("passcode", password);
            Message message = new Message();
            message.setCommand("CONNECT");
            message.setHeaders(headers);
            send(message, outputStream);
            Message received = receive(socket, inputStream);
            if (received == null) {
                throw new ConnectionException(Connection.TPESYSTEM, "Did not receive a response to CONNECT");
            }
            if ("ERROR".equals(received.getCommand())) {
                throw new ConnectionException(Connection.TPESYSTEM, new String(received.getBody()));
            }
            log.debug("Created socket: " + socket + " input: " + inputStream + "output: " + outputStream);
            connected = true;
            return socket;
        } finally {
            if (!connected) {
                // Nobody else holds a reference to the socket yet, so release it here rather than
                // leaking the descriptor when the handshake fails.
                try {
                    socket.close();
                } catch (IOException e) {
                    log.debug("Could not close socket after failed connect", e);
                }
            }
        }
    }

    /**
     * Writes one complete frame to the stream and flushes it. The whole frame is written while
     * holding the stream's monitor so that frames from different threads do not interleave.
     *
     * @param message the frame to write; its command must not be null, headers and body may be
     * @param outputStream the stream to write to, also used as the write lock
     * @throws IOException if writing fails
     */
    public static void send(Message message, OutputStream outputStream) throws IOException {
        log.trace("Writing on: " + outputStream);
        synchronized (outputStream) {
            outputStream.write(message.getCommand().getBytes());
            outputStream.write(EOL);
            Map<String, String> headers = message.getHeaders();
            if (headers != null) {
                for (Map.Entry<String, String> header : headers.entrySet()) {
                    outputStream.write(header.getKey().getBytes());
                    outputStream.write(COLON);
                    outputStream.write(header.getValue().getBytes());
                    outputStream.write(EOL);
                }
            }
            outputStream.write(EOL);
            if (message.getBody() != null) {
                outputStream.write(message.getBody());
            }
            outputStream.write(EOM);
            // Push the frame out even if the caller handed us a buffering stream.
            outputStream.flush();
        }
        log.trace("Wrote on: " + outputStream);
    }

    /**
     * Reads one frame from the stream while holding the socket's monitor.
     *
     * <ul>
     * <li>RECEIPT whose {@code receipt-id} is {@code disconnect}: the socket is recorded as
     * disconnected and {@code null} is returned.</li>
     * <li>Any other RECEIPT: returned without a body.</li>
     * <li>ERROR: returned with the {@code message} header as its body (empty if the header is
     * absent); the remainder of the frame is drained and logged.</li>
     * <li>Anything else: the body is read only when a {@code content-length} header is present.</li>
     * </ul>
     *
     * @param socket the connection being read, used as the read lock
     * @param inputStream the stream to read from
     * @return the frame, or {@code null} if no command was read or the frame was the disconnect receipt
     * @throws EOFException if the stream ends before the frame is complete
     * @throws IOException if reading fails or the frame is malformed
     */
    public static Message receive(Socket socket, InputStream inputStream) throws IOException {
        synchronized (socket) {
            log.trace("Reading from: " + inputStream);
            Message message = new Message();
            message.setCommand(readLine(inputStream));
            log.trace(message.getCommand());
            Map<String, String> headers = readHeaders(inputStream);
            message.setHeaders(headers);

            String command = message.getCommand();
            if (command == null) {
                log.trace("Read from: " + inputStream + " null");
                return null;
            }

            if (command.equals("RECEIPT")) {
                if ("disconnect".equals(headers.get("receipt-id"))) {
                    log.debug("Read disconnect receipt from: " + inputStream);
                    rememberDisconnected(socket);
                    message = null;
                } else {
                    log.trace("Read from: " + inputStream + " command was: " + command);
                }
                skipFrameEnd(inputStream);
            } else if (!command.equals("ERROR")) {
                String contentLength = headers.get("content-length");
                if (contentLength != null) {
                    byte[] body = readBody(inputStream, contentLength);
                    message.setBody(body);
                    log.trace("Read body of " + body.length + " bytes");
                }
                skipFrameEnd(inputStream);
                log.trace("Read from: " + inputStream + " command was: " + command);
            } else {
                String errorText = headers.get("message");
                message.setBody(errorText == null ? new byte[0] : errorText.getBytes());
                try {
                    // Drain off the error message up to the NUL, then the line that follows it
                    String read;
                    while ((read = readLine(inputStream)) != null) {
                        log.debug(read);
                    }
                    readLine(inputStream);
                } catch (EOFException e) {
                    // The error text has already been captured from the header; a peer that drops
                    // the connection straight after the ERROR frame must not hide it.
                    log.debug("Stream ended while draining the ERROR frame", e);
                }
                log.trace("Read from: " + inputStream + " command was: " + command);
            }
            return message;
        }
    }

    /**
     * Reads header lines up to the blank line that ends the header section. Lines without a
     * {@code ':'} after the first character are ignored; keys and values are trimmed.
     *
     * @throws IOException if a NUL is read before the blank line, i.e. the frame ended early
     */
    private static Map<String, String> readHeaders(InputStream inputStream) throws IOException {
        Map<String, String> headers = new HashMap<String, String>();
        String header = readLine(inputStream);
        while (header != null && header.length() > 0) {
            int sep = header.indexOf(':');
            if (sep > 0) {
                String key = header.substring(0, sep);
                String value = header.substring(sep + 1);
                headers.put(key.trim(), value.trim());
                log.trace("Header: " + key + ":" + value);
            }
            header = readLine(inputStream);
        }
        if (header == null) {
            throw new IOException("Malformed frame: terminator read before the end of the headers");
        }
        return headers;
    }

    /**
     * Reads exactly the number of bytes announced by the {@code content-length} header.
     *
     * @throws EOFException if the stream ends before the body is complete
     * @throws IOException if the header value is not a non-negative integer
     */
    private static byte[] readBody(InputStream inputStream, String contentLength) throws IOException {
        int length;
        try {
            length = Integer.parseInt(contentLength);
        } catch (NumberFormatException e) {
            throw new IOException("Invalid content-length: " + contentLength, e);
        }
        if (length < 0) {
            throw new IOException("Invalid content-length: " + contentLength);
        }
        byte[] body = new byte[length];
        int offset = 0;
        // read() may deliver fewer bytes than requested, so accumulate until the body is full
        while (offset < body.length) {
            int count = inputStream.read(body, offset, body.length - offset);
            if (count < 0) {
                throw new EOFException("Read the end of the stream");
            }
            offset += count;
        }
        return body;
    }

    /**
     * Consumes the two lines that follow the parsed part of a frame. The end of the stream is
     * tolerated on the second one so that a peer closing the connection right after the frame does
     * not discard a frame that was already read.
     */
    private static void skipFrameEnd(InputStream inputStream) throws IOException {
        readLine(inputStream);
        try {
            readLine(inputStream);
        } catch (EOFException e) {
            log.trace("Stream ended after the frame");
        }
    }

    /** Records that the disconnect receipt for the socket has been read. */
    private static void rememberDisconnected(Socket socket) {
        synchronized (disconnectedConnections) {
            disconnectedConnections.add(socket);
        }
    }

    /**
     * Removes one record of the socket's disconnect receipt.
     *
     * @return true if a record existed
     */
    private static boolean forgetDisconnected(Socket socket) {
        synchronized (disconnectedConnections) {
            return disconnectedConnections.remove(socket);
        }
    }

    /**
     * Reads bytes up to, and consuming, the next newline or NUL. Each byte is widened to a char
     * unchanged.
     *
     * @return the text before a newline (possibly empty), or {@code null} if a NUL ended the read,
     *         in which case any text before the NUL is discarded
     * @throws EOFException if the stream ends before a newline or NUL is seen
     */
    private static String readLine(InputStream inputStream) throws IOException {
        StringBuilder line = new StringBuilder();
        // Keep the value as an int: narrowing -1 to char yields '\uffff', which hides the end of stream.
        int b = inputStream.read();
        while (b != '\n' && b != '\000' && b != -1) {
            line.append((char) b);
            b = inputStream.read();
        }
        if (b == -1) {
            throw new EOFException("Read the end of the stream");
        }
        if (b == '\000') {
            log.trace("returning null");
            return null;
        }
        String toReturn = line.toString();
        log.trace("returning: " + toReturn);
        return toReturn;
    }
}