package org.jgroups.protocols; import org.jgroups.*; import org.jgroups.annotations.*; import org.jgroups.stack.Protocol; import org.jgroups.util.UUID; import org.jgroups.util.Util; import java.io.*; import java.net.*; import java.nio.ByteBuffer; import java.util.*; import java.util.concurrent.ConcurrentMap; /** * Protocol which provides STOMP (http://stomp.codehaus.org/) support. Very simple implementation, with a * one-thread-per-connection model. Use for a few hundred clients max.<p/> * The intended use for this protocol is pub-sub with clients which handle text messages, e.g. stock updates, * SMS messages to mobile clients, SNMP traps etc.<p/> * Note that the full STOMP protocol has not yet been implemented, e.g. transactions are not supported. * todo: use a thread pool to handle incoming frames and to send messages to clients * <p/> * todo: add PING to test health of client connections * <p/> * @author Bela Ban * @since 2.11 */ @MBean(description="Server side STOPM protocol, STOMP clients can connect to it") @Experimental public class STOMP extends Protocol implements Runnable { /* ----------------------------------------- Properties ----------------------------------------------- */ @LocalAddress @Property(name="bind_addr", description="The bind address which should be used by the server socket. The following special values " + "are also recognized: GLOBAL, SITE_LOCAL, LINK_LOCAL and NON_LOOPBACK", defaultValueIPv4="0.0.0.0", defaultValueIPv6="::", systemProperty={Global.STOMP_BIND_ADDR},writable=false) protected InetAddress bind_addr=null; @Property(description="If set, then endpoint will be set to this address",systemProperty=Global.STOMP_ENDPOINT_ADDR) protected String endpoint_addr; @Property(description="Port on which the STOMP protocol listens for requests",writable=false) protected int port=8787; @Property(description="If set to false, then a destination of /a/b match /a/b/c, a/b/d, a/b/c/d etc") protected boolean exact_destination_match=true; @Property(description="If true, information such as a list of endpoints, or views, will be sent to all clients " + "(via the INFO command). This allows for example intelligent clients to connect to " + "a different server should a connection be closed.") protected boolean send_info=true; @Property(description="Forward received messages which don't have a StompHeader to clients") protected boolean forward_non_client_generated_msgs=false; /* --------------------------------------------- JMX ---------------------------------------------------*/ @ManagedAttribute(description="Number of client connections",writable=false) public int getNumConnections() {return connections.size();} @ManagedAttribute(description="Number of subscriptions",writable=false) public int getNumSubscriptions() {return subscriptions.size();} @ManagedAttribute(description="Print subscriptions",writable=false) public String getSubscriptions() {return subscriptions.keySet().toString();} @ManagedAttribute public String getEndpoints() {return endpoints.toString();} /* --------------------------------------------- Fields ------------------------------------------------------ */ protected Address local_addr; protected ServerSocket srv_sock; @ManagedAttribute(writable=false) protected String endpoint; protected Thread acceptor; protected final List<Connection> connections=new LinkedList<Connection>(); protected final Map<Address,String> endpoints=new HashMap<Address,String>(); protected View view; // Subscriptions and connections which are subscribed protected final ConcurrentMap<String,Set<Connection>> subscriptions=Util.createConcurrentMap(20); public static enum ClientVerb {CONNECT, SEND, SUBSCRIBE, UNSUBSCRIBE, BEGIN, COMMIT, ABORT, ACK, DISCONNECT} public static enum ServerVerb {MESSAGE, RECEIPT, ERROR, CONNECTED, INFO} public static final byte NULL_BYTE=0; public STOMP() { } public void start() throws Exception { super.start(); srv_sock=Util.createServerSocket(getSocketFactory(), "jgroups.stomp.srv_sock", bind_addr, port); if(log.isDebugEnabled()) log.debug("server socket listening on " + srv_sock.getLocalSocketAddress()); if(acceptor == null) { acceptor=getThreadFactory().newThread(this, "STOMP acceptor"); acceptor.setDaemon(true); acceptor.start(); } endpoint=endpoint_addr != null? endpoint_addr : getAddress(); } public void stop() { if(log.isDebugEnabled()) log.debug("closing server socket " + srv_sock.getLocalSocketAddress()); if(acceptor != null && acceptor.isAlive()) { try { // this will terminate thread, peer will receive SocketException (socket close) getSocketFactory().close(srv_sock); } catch(Exception ex) { } } synchronized(connections) { for(Connection conn: connections) conn.stop(); connections.clear(); } acceptor=null; super.stop(); } // Acceptor loop public void run() { Socket client_sock; while(acceptor != null && srv_sock != null) { try { client_sock=srv_sock.accept(); if(log.isTraceEnabled()) log.trace("accepted connection from " + client_sock.getInetAddress() + ':' + client_sock.getPort()); Connection conn=new Connection(client_sock); Thread thread=getThreadFactory().newThread(conn, "STOMP client connection"); thread.setDaemon(true); synchronized(connections) { connections.add(conn); } thread.start(); conn.sendInfo(); } catch(IOException io_ex) { break; } } acceptor=null; } public Object down(Event evt) { switch(evt.getType()) { case Event.VIEW_CHANGE: handleView((View)evt.getArg()); break; case Event.SET_LOCAL_ADDRESS: local_addr=(Address)evt.getArg(); break; } return down_prot.down(evt); } public Object up(Event evt) { switch(evt.getType()) { case Event.MSG: Message msg=(Message)evt.getArg(); StompHeader hdr=(StompHeader)msg.getHeader(id); if(hdr == null) { if(forward_non_client_generated_msgs) { HashMap<String, String> hdrs=new HashMap<String, String>(); hdrs.put("sender", msg.getSrc().toString()); sendToClients(hdrs, msg.getRawBuffer(), msg.getOffset(), msg.getLength()); } break; } switch(hdr.type) { case MESSAGE: sendToClients(hdr.headers, msg.getRawBuffer(), msg.getOffset(), msg.getLength()); break; case ENDPOINT: String tmp_endpoint=hdr.headers.get("endpoint"); if(tmp_endpoint != null) { boolean update_clients; String old_endpoint=null; synchronized(endpoints) { endpoints.put(msg.getSrc(), tmp_endpoint); } update_clients=old_endpoint == null || !old_endpoint.equals(tmp_endpoint); if(update_clients && this.send_info) { synchronized(connections) { for(Connection conn: connections) { conn.writeResponse(ServerVerb.INFO, "endpoints", getAllEndpoints()); } } } } return null; default: throw new IllegalArgumentException("type " + hdr.type + " is not known"); } break; case Event.VIEW_CHANGE: handleView((View)evt.getArg()); break; } return up_prot.up(evt); } public static Frame readFrame(DataInputStream in) throws IOException { String verb=Util.readLine(in); if(verb == null) throw new EOFException("reading verb"); if(verb.length() == 0) return null; verb=verb.trim(); Map<String,String> headers=new HashMap<String,String>(); byte[] body=null; for(;;) { String header=Util.readLine(in); if(header == null) throw new EOFException("reading header"); if(header.length() == 0) break; int index=header.indexOf(":"); if(index != -1) headers.put(header.substring(0, index).trim(), header.substring(index+1).trim()); } if(headers.containsKey("content-length")) { int length=Integer.parseInt(headers.get("content-length")); body=new byte[length]; in.read(body, 0, body.length); } else { ByteBuffer buf=ByteBuffer.allocate(500); boolean terminate=false; for(;;) { int c=in.read(); if(c == -1 || c == 0) terminate=true; if(buf.remaining() == 0 || terminate) { if(body == null) { body=new byte[buf.position()]; System.arraycopy(buf.array(), buf.arrayOffset(), body, 0, buf.position()); } else { byte[] tmp=new byte[body.length + buf.position()]; System.arraycopy(body, 0, tmp, 0, body.length); try { System.arraycopy(buf.array(), buf.arrayOffset(), tmp, body.length, buf.position()); } catch(Throwable t) { } body=tmp; } buf.rewind(); } if(terminate) break; buf.put((byte)c); } } return new Frame(verb, headers, body); } protected void handleView(View view) { broadcastEndpoint(); List<Address> mbrs=view.getMembers(); this.view=view; synchronized(endpoints) { endpoints.keySet().retainAll(mbrs); } synchronized(connections) { for(Connection conn: connections) conn.sendInfo(); } } private String getAddress() { InetSocketAddress saddr=(InetSocketAddress)srv_sock.getLocalSocketAddress(); InetAddress tmp=saddr.getAddress(); if(!tmp.isAnyLocalAddress()) return tmp.getHostAddress() + ":" + srv_sock.getLocalPort(); for(Util.AddressScope scope: Util.AddressScope.values()) { try { InetAddress addr=Util.getAddress(scope); if(addr != null) return addr.getHostAddress() + ":" + srv_sock.getLocalPort(); } catch(SocketException e) { } } return null; } protected String getAllEndpoints() { synchronized(endpoints) { return Util.printListWithDelimiter(endpoints.values(), ","); } } // protected String getAllClients() { // StringBuilder sb=new StringBuilder(); // boolean first=true; // // synchronized(connections) { // for(Connection conn: connections) { // UUID session_id=conn.session_id; // if(session_id != null) { // if(first) // first=false; // else // sb.append(","); // sb.append(session_id); // } // } // } // // return sb.toString(); // } protected void broadcastEndpoint() { if(endpoint != null) { Message msg=new Message(); msg.putHeader(id, StompHeader.createHeader(StompHeader.Type.ENDPOINT, "endpoint", endpoint)); down_prot.down(new Event(Event.MSG, msg)); } } // private void sendToClients(String destination, String sender, byte[] buffer, int offset, int length) { // int len=50 + length + (ServerVerb.MESSAGE.name().length() + 2) // + (destination != null? destination.length()+ 2 : 0) // + (sender != null? sender.length() +2 : 0) // + (buffer != null? 20 : 0); // // ByteBuffer buf=ByteBuffer.allocate(len); // // StringBuilder sb=new StringBuilder(ServerVerb.MESSAGE.name()).append("\n"); // if(destination != null) // sb.append("destination: ").append(destination).append("\n"); // if(sender != null) // sb.append("sender: ").append(sender).append("\n"); // if(buffer != null) // sb.append("content-length: ").append(String.valueOf(length)).append("\n"); // sb.append("\n"); // // byte[] tmp=sb.toString().getBytes(); // // if(buffer != null) { // buf.put(tmp, 0, tmp.length); // buf.put(buffer, offset, length); // } // buf.put(NULL_BYTE); // // final Set<Connection> target_connections=new HashSet<Connection>(); // if(destination == null) { // synchronized(connections) { // target_connections.addAll(connections); // } // } // else { // if(!exact_destination_match) { // for(Map.Entry<String,Set<Connection>> entry: subscriptions.entrySet()) { // if(entry.getKey().startsWith(destination)) // target_connections.addAll(entry.getValue()); // } // } // else { // Set<Connection> conns=subscriptions.get(destination); // if(conns != null) // target_connections.addAll(conns); // } // } // // for(Connection conn: target_connections) // conn.writeResponse(buf.array(), buf.arrayOffset(), buf.position()); // } private void sendToClients(Map<String,String> headers, byte[] buffer, int offset, int length) { int len=50 + length + (ServerVerb.MESSAGE.name().length() + 2); if(headers != null) { for(Map.Entry<String,String> entry: headers.entrySet()) { len+=entry.getKey().length() +2; len+=entry.getValue().length() +2; len+=5; // fill chars, such as ": " or "\n" } } len+=(buffer != null? 20 : 0); ByteBuffer buf=ByteBuffer.allocate(len); StringBuilder sb=new StringBuilder(ServerVerb.MESSAGE.name()).append("\n"); if(headers != null) { for(Map.Entry<String,String> entry: headers.entrySet()) sb.append(entry.getKey()).append(": ").append(entry.getValue()).append("\n"); } if(buffer != null) sb.append("content-length: ").append(String.valueOf(length)).append("\n"); sb.append("\n"); byte[] tmp=sb.toString().getBytes(); if(buffer != null) { buf.put(tmp, 0, tmp.length); buf.put(buffer, offset, length); } buf.put(NULL_BYTE); final Set<Connection> target_connections=new HashSet<Connection>(); String destination=headers != null? headers.get("destination") : null; if(destination == null) { synchronized(connections) { target_connections.addAll(connections); } } else { if(!exact_destination_match) { for(Map.Entry<String,Set<Connection>> entry: subscriptions.entrySet()) { if(entry.getKey().startsWith(destination)) target_connections.addAll(entry.getValue()); } } else { Set<Connection> conns=subscriptions.get(destination); if(conns != null) target_connections.addAll(conns); } } for(Connection conn: target_connections) conn.writeResponse(buf.array(), buf.arrayOffset(), buf.position()); } /** * Class which handles a connection to a client */ public class Connection implements Runnable { protected final Socket sock; protected final DataInputStream in; protected final DataOutputStream out; protected final UUID session_id=UUID.randomUUID(); public Connection(Socket sock) throws IOException { this.sock=sock; this.in=new DataInputStream(sock.getInputStream()); this.out=new DataOutputStream(sock.getOutputStream()); } public void stop() { if(log.isTraceEnabled()) log.trace("closing connection to " + sock.getRemoteSocketAddress()); Util.close(in); Util.close(out); Util.close(sock); } protected void remove() { synchronized(connections) { connections.remove(this); } for(Set<Connection> conns: subscriptions.values()) { conns.remove(this); } for(Iterator<Map.Entry<String,Set<Connection>>> it=subscriptions.entrySet().iterator(); it.hasNext();) { Map.Entry<String,Set<Connection>> entry=it.next(); if(entry.getValue().isEmpty()) it.remove(); } } public void run() { while(!sock.isClosed()) { try { Frame frame=readFrame(in); if(frame != null) { if(log.isTraceEnabled()) log.trace(frame); handleFrame(frame); } } catch(IOException ex) { stop(); remove(); } catch(Throwable t) { log.error("failure reading frame", t); } } } protected void handleFrame(Frame frame) { Map<String,String> headers=frame.getHeaders(); ClientVerb verb=ClientVerb.valueOf(frame.getVerb()); switch(verb) { case CONNECT: writeResponse(ServerVerb.CONNECTED, "session-id", session_id.toString(), "password-check", "none"); break; case SEND: if(!headers.containsKey("sender")) { headers.put("sender", session_id.toString()); } Message msg=new Message(null, null, frame.getBody()); Header hdr=StompHeader.createHeader(StompHeader.Type.MESSAGE, headers); msg.putHeader(id, hdr); down_prot.down(new Event(Event.MSG, msg)); String receipt=headers.get("receipt"); if(receipt != null) writeResponse(ServerVerb.RECEIPT, "receipt-id", receipt); break; case SUBSCRIBE: String destination=headers.get("destination"); if(destination != null) { Set<Connection> conns=subscriptions.get(destination); if(conns == null) { conns=new HashSet<Connection>(); Set<Connection> tmp=subscriptions.putIfAbsent(destination, conns); if(tmp != null) conns=tmp; } conns.add(this); } break; case UNSUBSCRIBE: destination=headers.get("destination"); if(destination != null) { Set<Connection> conns=subscriptions.get(destination); if(conns != null) { if(conns.remove(this) && conns.isEmpty()) subscriptions.remove(destination); } } break; case BEGIN: break; case COMMIT: break; case ABORT: break; case ACK: break; case DISCONNECT: break; default: log.error("Verb " + frame.getVerb() + " is not handled"); break; } } public void sendInfo() { if(send_info) { writeResponse(ServerVerb.INFO, "local_addr", local_addr != null? local_addr.toString() : "n/a", "view", view.toString(), "endpoints", getAllEndpoints()); // "clients", getAllClients()); } } /** * Sends back a response. The keys_and_values vararg array needs to have an even number of elements * @param response * @param keys_and_values */ private void writeResponse(ServerVerb response, String ... keys_and_values) { String tmp=response.name(); try { out.write(tmp.getBytes()); out.write('\n'); for(int i=0; i < keys_and_values.length; i++) { String key=keys_and_values[i]; String val=keys_and_values[++i]; out.write((key + ": " + val + "\n").getBytes()); } out.write("\n".getBytes()); out.write(NULL_BYTE); out.flush(); } catch(IOException ex) { log.error("failed writing response " + response + ": " + ex); } } private void writeResponse(byte[] response, int offset, int length) { try { out.write(response, offset, length); out.flush(); } catch(IOException ex) { log.error("failed writing response: " + ex); } } } public static class Frame { final String verb; final Map<String,String> headers; final byte[] body; public Frame(String verb, Map<String, String> headers, byte[] body) { this.verb=verb; this.headers=headers; this.body=body; } public byte[] getBody() { return body; } public Map<String, String> getHeaders() { return headers; } public String getVerb() { return verb; } public String toString() { StringBuilder sb=new StringBuilder(); sb.append(verb).append("\n"); if(headers != null && !headers.isEmpty()) { for(Map.Entry<String,String> entry: headers.entrySet()) sb.append(entry.getKey()).append(": ").append(entry.getValue()).append("\n"); } if(body != null && body.length > 0) { sb.append("body: "); if(body.length < 50) sb.append(new String(body)).append(" (").append(body.length).append(" bytes)"); else sb.append(body.length).append(" bytes"); } return sb.toString(); } } public static class StompHeader extends org.jgroups.Header { public static enum Type {MESSAGE, ENDPOINT} protected Type type; protected final Map<String,String> headers=new HashMap<String,String>(); public StompHeader() { } private StompHeader(Type type) { this.type=type; } /** * Creates a new header * @param type * @param headers Keys and values to be added to the header hashmap. Needs to be an even number * @return */ public static StompHeader createHeader(Type type, String ... headers) { StompHeader retval=new StompHeader(type); if(headers != null) { for(int i=0; i < headers.length; i++) { String key=headers[i]; String value=headers[++i]; retval.headers.put(key, value); } } return retval; } public static StompHeader createHeader(Type type, Map<String,String> headers) { StompHeader retval=new StompHeader(type); if(headers != null) retval.headers.putAll(headers); return retval; } public int size() { int retval=Global.INT_SIZE *2; // type + size of hashmap for(Map.Entry<String,String> entry: headers.entrySet()) { retval+=entry.getKey().length() +2; retval+=entry.getValue().length() +2; } return retval; } public void writeTo(DataOutput out) throws Exception { out.writeInt(type.ordinal()); out.writeInt(headers.size()); for(Map.Entry<String,String> entry: headers.entrySet()) { out.writeUTF(entry.getKey()); out.writeUTF(entry.getValue()); } } public void readFrom(DataInput in) throws Exception { type=Type.values()[in.readInt()]; int size=in.readInt(); for(int i=0; i < size; i++) { String key=in.readUTF(); String value=in.readUTF(); headers.put(key, value); } } public String toString() { StringBuilder sb=new StringBuilder(type.toString()); sb.append("headers: ").append(headers); return sb.toString(); } } }