/* * Copyright (C) 2012 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.jboss.errai.bus.server.cluster.jgroups; import static org.jboss.errai.bus.server.cluster.ClusterParts.BusId; import static org.jboss.errai.bus.server.cluster.ClusterParts.MessageId; import static org.jboss.errai.bus.server.cluster.ClusterParts.Payload; import static org.jboss.errai.bus.server.cluster.ClusterParts.SessId; import static org.jboss.errai.bus.server.cluster.ClusterParts.Subject; import static org.jboss.errai.common.client.protocols.MessageParts.CommandType; import static org.jboss.errai.common.client.protocols.MessageParts.SessionID; import static org.jboss.errai.common.client.protocols.MessageParts.ToSubject; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.inject.Inject; import org.jboss.errai.bus.client.api.QueueSession; import org.jboss.errai.bus.client.api.RoutingFlag; import org.jboss.errai.bus.client.api.base.CommandMessage; import org.jboss.errai.bus.client.api.messaging.Message; import org.jboss.errai.bus.client.api.messaging.MessageCallback; import org.jboss.errai.bus.server.QueueUnavailableException; import org.jboss.errai.bus.server.api.MessageQueue; import org.jboss.errai.bus.server.api.ServerMessageBus; import org.jboss.errai.bus.server.cluster.ClusterCommands; import org.jboss.errai.bus.server.cluster.ClusterParts; import org.jboss.errai.bus.server.cluster.ClusteringProvider; import org.jboss.errai.bus.server.cluster.IntrabusQueueSession; import org.jboss.errai.bus.server.io.MessageFactory; import org.jboss.errai.bus.server.service.ErraiConfigAttribs; import org.jboss.errai.bus.server.service.ErraiService; import org.jboss.errai.bus.server.service.ErraiServiceConfigurator; import org.jboss.errai.bus.server.util.SecureHashUtil; import org.jboss.errai.common.client.protocols.Resources; import org.jboss.errai.marshalling.client.protocols.ErraiProtocol; import org.jgroups.Address; import org.jgroups.JChannel; import org.jgroups.ReceiverAdapter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * * @author Mike Brock */ public class JGroupsClusteringProvider extends ReceiverAdapter implements ClusteringProvider, MessageCallback { private static final String CLUSTER_SERVICE = "local:ErraiClusterService"; // erraibus service private final String busId = SecureHashUtil.nextSecureHash(); private final JChannel jchannel; private final ServerMessageBus serverMessageBus; final Cache<String, Address> sessionToNodeCache; private final static String JGROUPS_MESSAGE_RESOURCE = "JGroupsMessage"; private static Logger log = LoggerFactory.getLogger(JGroupsClusteringProvider.class); @Inject private JGroupsClusteringProvider(final ServerMessageBus messageBus, final ErraiServiceConfigurator config, final ErraiService erraiService) { this.serverMessageBus = messageBus; try { jchannel = new JChannel(JGroupsConfigAttribs.JGROUPS_PROTOCOL_STACK.get(config)); jchannel.connect(ErraiConfigAttribs.CLUSTER_NAME.get(config)); // I don't think waiting for the state is necessary. // jchannel.getState(null, 2000); } catch (Exception e) { throw new RuntimeException(e); } serverMessageBus.subscribe(CLUSTER_SERVICE, this); jchannel.setReceiver(this); erraiService.addShutdownHook(new Runnable() { @Override public void run() { jchannel.close(); log.info("shut down jgroups clustering service"); } }); sessionToNodeCache = CacheBuilder.newBuilder() .maximumSize(100) .build(); log.info("starting errai clustering service."); } @Override public void receive(final org.jgroups.Message msg) { try { final Message erraiMessage = getErraiMessage(msg); erraiMessage.setResource(JGROUPS_MESSAGE_RESOURCE, msg); if (busId.equals(erraiMessage.get(String.class, BusId))) { return; } erraiMessage.setFlag(RoutingFlag.FromPeer); serverMessageBus.sendGlobal(erraiMessage); } catch (Exception e) { e.printStackTrace(); } } @Override public void callback(final Message message) { final QueueSession queueSession = message.getResource(QueueSession.class, "Session"); if (queueSession != IntrabusQueueSession.INSTANCE) { log.warn("message to cluster service ('" + CLUSTER_SERVICE + "') originating from illegal session. " + " message was discarded."); return; } switch (ClusterCommands.valueOf(message.getCommandType())) { case WhoHandles: { final String subject = message.get(String.class, Subject); if (serverMessageBus.hasRemoteSubscriptions(subject)) { final String sessionIdRequested = message.get(String.class, ClusterParts.SessId); try { if (serverMessageBus.getQueueBySession(sessionIdRequested) == null) { return; } } catch (QueueUnavailableException e) { return; } final org.jgroups.Message jgroupsMessage = message.getResource(org.jgroups.Message.class, JGROUPS_MESSAGE_RESOURCE); final Message replyMsg = CommandMessage.create() .set(ToSubject, CLUSTER_SERVICE) .set(CommandType, ClusterCommands.NotifyOwner.name()) .set(BusId, busId) .copy(MessageId, message) .set(ClusterParts.SessId, sessionIdRequested); try { jchannel.send(jgroupsMessage.getSrc(), ErraiProtocol.encodePayload(replyMsg.getParts())); } catch (Exception e) { e.printStackTrace(); } } } break; case NotifyOwner: { final String messageId = message.get(String.class, MessageId); final String sessId = message.get(String.class, SessId); final Message deferredMessage = serverMessageBus.getDeadLetterMessage(messageId); serverMessageBus.removeDeadLetterMessage(messageId); final org.jgroups.Message jgroupsMessage = message.getResource(org.jgroups.Message.class, JGROUPS_MESSAGE_RESOURCE); sessionToNodeCache.put(sessId, jgroupsMessage.getSrc()); if (deferredMessage != null) { final Message dMessage = createForwardMessageFor(deferredMessage, messageId); try { jchannel.send(jgroupsMessage.getSrc(), ErraiProtocol.encodePayload(dMessage.getParts())); } catch (Exception e) { e.printStackTrace(); } } } break; case InvalidRoute: { final String sessionId = message.get(String.class, SessId); sessionToNodeCache.invalidate(sessionId); final String messageId = message.get(String.class, MessageId); final String subject = message.get(String.class, Subject); final Message whoMessage = createWhoHandlesMessage(sessionId, subject, messageId); try { jchannel.send(getJGroupsMessage(whoMessage)); } catch (Exception e) { e.printStackTrace(); } break; } case MessageForward: { final String payload = message.get(String.class, Payload); final Message forwardMessage = MessageFactory.createCommandMessage(IntrabusQueueSession.INSTANCE, payload); forwardMessage.setFlag(RoutingFlag.FromPeer); final String sessId = message.get(String.class, SessId); if (sessId == null) { serverMessageBus.sendGlobal(forwardMessage); } else { final MessageQueue messageQueue; try { messageQueue = serverMessageBus.getQueueBySession(sessId); } catch (QueueUnavailableException e) { final org.jgroups.Message jgroupsMessage = message.getResource(org.jgroups.Message.class, JGROUPS_MESSAGE_RESOURCE); final String messageId = message.get(String.class, MessageId); final Message invalidRoute = createInvalidRouteMessage(sessId, forwardMessage.getSubject(), messageId); try { jchannel.send(jgroupsMessage.getSrc(), ErraiProtocol.encodePayload(invalidRoute.getParts())); } catch (Exception e2) { e2.printStackTrace(); } return; } // otherwise route it directly to the client. forwardMessage.setResource(Resources.Session.name(), messageQueue.getSession()); serverMessageBus.send(forwardMessage); } } break; } } @Override public void clusterTransmit(final String sessionId, final String subject, final String messageId) { final Address knownAddress = sessionToNodeCache.getIfPresent(sessionId); if (knownAddress != null) { final Message forwardMessage = createForwardMessageFor(serverMessageBus.getDeadLetterMessage(messageId), messageId); try { jchannel.send(knownAddress, ErraiProtocol.encodePayload(forwardMessage.getParts())); } catch (Exception e) { e.printStackTrace(); } } else { final Message whoHandlesMessage = createWhoHandlesMessage(sessionId, subject, messageId); try { jchannel.send(getJGroupsMessage(whoHandlesMessage)); } catch (Exception e) { e.printStackTrace(); } } } private Message createForwardMessageFor(final Message message, final String messageId) { final Message forward = CommandMessage.create() .set(ToSubject, CLUSTER_SERVICE) .set(CommandType, ClusterCommands.MessageForward.name()) .set(Payload, ErraiProtocol.encodePayload(message.getParts())) .set(BusId, busId); if (message.hasPart(SessionID)) { final String value = message.get(String.class, SessionID); if (!IntrabusQueueSession.INSTANCE.getSessionId().equals(value)) { forward.set(SessId, value); } } if (messageId != null) { forward.set(MessageId, messageId); } return forward; } private Message createInvalidRouteMessage(final String sessionId, final String subject, final String messageId) { return CommandMessage.create() .set(ToSubject, CLUSTER_SERVICE) .set(CommandType, ClusterCommands.InvalidRoute.name()) .set(SessId, sessionId) .set(Subject, subject) .set(MessageId, messageId) .set(BusId, busId); } private Message createWhoHandlesMessage(final String sessionId, final String subject, final String messageId) { return CommandMessage.create() .set(ToSubject, CLUSTER_SERVICE) .set(CommandType, ClusterCommands.WhoHandles.name()) .set(ClusterParts.SessId, sessionId) .set(BusId, busId) .set(Subject, subject) .set(MessageId, messageId); } @Override public void clusterTransmitGlobal(final Message message) { try { jchannel.send(getJGroupsMessage(createForwardMessageFor(message, null))); } catch (Exception e) { e.printStackTrace(); } } public static Message getErraiMessage(final org.jgroups.Message message) { return MessageFactory.createCommandMessage(IntrabusQueueSession.INSTANCE, String.valueOf(message.getObject())); } private static org.jgroups.Message getJGroupsMessage(final Message message) { return new org.jgroups.Message(null, null, ErraiProtocol.encodePayload(message.getParts())); } }