/*
* Copyright 2014-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.scattergather;
import org.springframework.aop.support.AopUtils;
import org.springframework.context.Lifecycle;
import org.springframework.integration.channel.FixedSubscriberChannel;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.context.IntegrationContextUtils;
import org.springframework.integration.core.MessageProducer;
import org.springframework.integration.endpoint.AbstractEndpoint;
import org.springframework.integration.endpoint.EventDrivenConsumer;
import org.springframework.integration.endpoint.PollingConsumer;
import org.springframework.integration.handler.AbstractReplyProducingMessageHandler;
import org.springframework.integration.support.channel.HeaderChannelRegistry;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageDeliveryException;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.PollableChannel;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
/**
* The {@link MessageHandler} implementation for the
* <a href="http://www.eaipatterns.com/BroadcastAggregate.html">Scatter-Gather</a> EIP pattern.
*
* @author Artem Bilan
* @since 4.1
*/
public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler implements Lifecycle {
private static final String GATHER_RESULT_CHANNEL = "gatherResultChannel";
private final MessageChannel scatterChannel;
private final MessageHandler gatherer;
private MessageChannel gatherChannel;
private long gatherTimeout = -1;
private AbstractEndpoint gatherEndpoint;
private HeaderChannelRegistry replyChannelRegistry;
public ScatterGatherHandler(MessageChannel scatterChannel, MessageHandler gatherer) {
Assert.notNull(scatterChannel, "'scatterChannel' must not be null");
Assert.notNull(gatherer, "'gatherer' must not be null");
Class<?> gathererClass = AopUtils.getTargetClass(gatherer);
checkClass(gathererClass, "org.springframework.integration.aggregator.AggregatingMessageHandler", "gatherer");
this.scatterChannel = scatterChannel;
this.gatherer = gatherer;
}
public ScatterGatherHandler(MessageHandler scatterer, MessageHandler gatherer) {
this(new FixedSubscriberChannel(scatterer), gatherer);
Assert.notNull(scatterer, "'scatterer' must not be null");
Class<?> scattererClass = AopUtils.getTargetClass(scatterer);
checkClass(scattererClass, "org.springframework.integration.router.RecipientListRouter", "scatterer");
}
public void setGatherChannel(MessageChannel gatherChannel) {
this.gatherChannel = gatherChannel;
}
public void setGatherTimeout(long gatherTimeout) {
this.gatherTimeout = gatherTimeout;
}
@Override
protected void doInit() {
if (this.gatherChannel == null) {
this.gatherChannel = new FixedSubscriberChannel(this.gatherer);
}
else {
if (this.gatherChannel instanceof SubscribableChannel) {
this.gatherEndpoint = new EventDrivenConsumer((SubscribableChannel) this.gatherChannel, this.gatherer);
}
else if (this.gatherChannel instanceof PollableChannel) {
this.gatherEndpoint = new PollingConsumer((PollableChannel) this.gatherChannel, this.gatherer);
((PollingConsumer) this.gatherEndpoint).setReceiveTimeout(this.gatherTimeout);
}
else {
throw new MessagingException("Unsupported 'replyChannel' type [" + this.gatherChannel.getClass() + "]."
+ "SubscribableChannel or PollableChannel type are supported.");
}
this.gatherEndpoint.setBeanFactory(this.getBeanFactory());
this.gatherEndpoint.afterPropertiesSet();
}
((MessageProducer) this.gatherer).setOutputChannel(new FixedSubscriberChannel(message -> {
MessageHeaders headers = message.getHeaders();
if (headers.containsKey(GATHER_RESULT_CHANNEL)) {
Object gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL);
if (gatherResultChannel instanceof MessageChannel) {
messagingTemplate.send((MessageChannel) gatherResultChannel, message);
}
else if (gatherResultChannel instanceof String) {
messagingTemplate.send((String) gatherResultChannel, message);
}
}
else {
throw new MessageDeliveryException(message,
"The 'gatherResultChannel' header is required to delivery gather result.");
}
}));
this.replyChannelRegistry = getBeanFactory()
.getBean(IntegrationContextUtils.INTEGRATION_HEADER_CHANNEL_REGISTRY_BEAN_NAME,
HeaderChannelRegistry.class);
}
@Override
protected Object handleRequestMessage(Message<?> requestMessage) {
PollableChannel gatherResultChannel = new QueueChannel();
Object gatherResultChannelName = this.replyChannelRegistry.channelToChannelName(gatherResultChannel);
Message<?> scatterMessage = getMessageBuilderFactory()
.fromMessage(requestMessage)
.setHeader(GATHER_RESULT_CHANNEL, gatherResultChannelName)
.setReplyChannel(this.gatherChannel)
.build();
this.messagingTemplate.send(this.scatterChannel, scatterMessage);
Message<?> gatherResult = gatherResultChannel.receive(this.gatherTimeout);
if (gatherResult != null) {
return gatherResult;
}
return null;
}
@Override
public void start() {
if (this.gatherEndpoint != null) {
this.gatherEndpoint.start();
}
}
@Override
public void stop() {
if (this.gatherEndpoint != null) {
this.gatherEndpoint.start();
}
}
@Override
public boolean isRunning() {
return this.gatherEndpoint == null || this.gatherEndpoint.isRunning();
}
private void checkClass(Class<?> gathererClass, String className, String type) throws LinkageError {
Class<?> clazz = null;
try {
clazz = ClassUtils.forName(className, getClass().getClassLoader());
}
catch (Exception e) {
}
Assert.isAssignable(clazz, gathererClass, "the '" + type + "' must be an " + className + " instance");
}
}