/*
* Copyright 2013-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.aws.messaging.listener;
import com.amazonaws.services.sqs.model.DeleteMessageRequest;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.ReceiveMessageResult;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.messaging.MessagingException;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static org.springframework.cloud.aws.messaging.core.QueueMessageUtils.createMessage;
/**
* @author Agim Emruli
* @author Alain Sahli
* @since 1.0
*/
public class SimpleMessageListenerContainer extends AbstractMessageListenerContainer {
private static final int DEFAULT_WORKER_THREADS = 2;
private static final String DEFAULT_THREAD_NAME_PREFIX =
ClassUtils.getShortName(SimpleMessageListenerContainer.class) + "-";
private boolean defaultTaskExecutor;
private long backOffTime = 10000;
private long queueStopTimeout = 10000;
private AsyncTaskExecutor taskExecutor;
private ConcurrentHashMap<String, Future<?>> scheduledFutureByQueue;
private ConcurrentHashMap<String, Boolean> runningStateByQueue;
protected AsyncTaskExecutor getTaskExecutor() {
return this.taskExecutor;
}
public void setTaskExecutor(AsyncTaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
}
/**
* @return The number of milliseconds the polling thread must wait before trying to recover when an error occurs
* (e.g. connection timeout)
*/
public long getBackOffTime() {
return this.backOffTime;
}
/**
* The number of milliseconds the polling thread must wait before trying to recover when an error occurs
* (e.g. connection timeout). Default is 10000 milliseconds.
*
* @param backOffTime
* in milliseconds
*/
public void setBackOffTime(long backOffTime) {
this.backOffTime = backOffTime;
}
/**
* @return The number of milliseconds the {@link SimpleMessageListenerContainer#stop(String)} method waits for a queue
* to stop before interrupting the current thread. Default value is 10000 milliseconds (10 seconds).
*/
public long getQueueStopTimeout() {
return this.queueStopTimeout;
}
/**
* The number of milliseconds the {@link SimpleMessageListenerContainer#stop(String)} method waits for a queue
* to stop before interrupting the current thread. Default value is 10000 milliseconds (10 seconds).
*
* @param queueStopTimeout
* in milliseconds
*/
public void setQueueStopTimeout(long queueStopTimeout) {
this.queueStopTimeout = queueStopTimeout;
}
@Override
protected void initialize() {
if (this.taskExecutor == null) {
this.defaultTaskExecutor = true;
this.taskExecutor = createDefaultTaskExecutor();
}
super.initialize();
initializeRunningStateByQueue();
this.scheduledFutureByQueue = new ConcurrentHashMap<>(getRegisteredQueues().size());
}
private void initializeRunningStateByQueue() {
this.runningStateByQueue = new ConcurrentHashMap<>(getRegisteredQueues().size());
for (String queueName : getRegisteredQueues().keySet()) {
this.runningStateByQueue.put(queueName, false);
}
}
@Override
protected void doStart() {
synchronized (this.getLifecycleMonitor()) {
scheduleMessageListeners();
}
}
@Override
protected void doStop() {
notifyRunningQueuesToStop();
waitForRunningQueuesToStop();
}
private void notifyRunningQueuesToStop() {
for (Map.Entry<String, Boolean> runningStateByQueue : this.runningStateByQueue.entrySet()) {
if (runningStateByQueue.getValue()) {
stopQueue(runningStateByQueue.getKey());
}
}
}
private void waitForRunningQueuesToStop() {
for (Map.Entry<String, Boolean> queueRunningState : this.runningStateByQueue.entrySet()) {
String logicalQueueName = queueRunningState.getKey();
Future<?> queueSpinningThread = this.scheduledFutureByQueue.get(logicalQueueName);
if (queueSpinningThread != null) {
try {
queueSpinningThread.get(getQueueStopTimeout(), TimeUnit.SECONDS);
} catch (ExecutionException | TimeoutException e) {
getLogger().warn("An exception occurred while stopping queue '" + logicalQueueName + "'", e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
}
@Override
protected void doDestroy() {
if (this.defaultTaskExecutor) {
((ThreadPoolTaskExecutor) this.taskExecutor).destroy();
}
}
/**
* Create a default TaskExecutor. Called if no explicit TaskExecutor has been specified.
* <p>The default implementation builds a {@link org.springframework.core.task.SimpleAsyncTaskExecutor}
* with the specified bean name (or the class name, if no bean name specified) as thread name prefix.
*
* @return a {@link org.springframework.core.task.SimpleAsyncTaskExecutor} configured with the thread name prefix
* @see org.springframework.core.task.SimpleAsyncTaskExecutor#SimpleAsyncTaskExecutor(String)
*/
protected AsyncTaskExecutor createDefaultTaskExecutor() {
String beanName = getBeanName();
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
threadPoolTaskExecutor.setThreadNamePrefix(beanName != null ? beanName + "-" : DEFAULT_THREAD_NAME_PREFIX);
int spinningThreads = this.getRegisteredQueues().size();
if (spinningThreads > 0) {
threadPoolTaskExecutor.setCorePoolSize(spinningThreads * DEFAULT_WORKER_THREADS);
int maxNumberOfMessagePerBatch = getMaxNumberOfMessages() != null ? getMaxNumberOfMessages() : DEFAULT_WORKER_THREADS;
threadPoolTaskExecutor.setMaxPoolSize(spinningThreads * maxNumberOfMessagePerBatch);
}
// No use of a thread pool executor queue to avoid retaining message to long in memory
threadPoolTaskExecutor.setQueueCapacity(0);
threadPoolTaskExecutor.afterPropertiesSet();
return threadPoolTaskExecutor;
}
private void scheduleMessageListeners() {
for (Map.Entry<String, QueueAttributes> registeredQueue : getRegisteredQueues().entrySet()) {
startQueue(registeredQueue.getKey(), registeredQueue.getValue());
}
}
protected void executeMessage(org.springframework.messaging.Message<String> stringMessage) {
getMessageHandler().handleMessage(stringMessage);
}
/**
* Stops and waits until the specified queue has stopped. If the wait timeout specified by {@link SimpleMessageListenerContainer#getQueueStopTimeout()}
* is reached, the current thread is interrupted.
*
* @param logicalQueueName
* the name as defined on the listener method
*/
public void stop(String logicalQueueName) {
stopQueue(logicalQueueName);
try {
if (isRunning(logicalQueueName)) {
Future<?> future = this.scheduledFutureByQueue.remove(logicalQueueName);
future.get(this.queueStopTimeout, TimeUnit.MILLISECONDS);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (ExecutionException | TimeoutException e) {
getLogger().warn("Error stopping queue with name: '" + logicalQueueName + "'", e);
}
}
protected void stopQueue(String logicalQueueName) {
Assert.isTrue(this.runningStateByQueue.containsKey(logicalQueueName), "Queue with name '" + logicalQueueName + "' does not exist");
this.runningStateByQueue.put(logicalQueueName, false);
}
public void start(String logicalQueueName) {
Assert.isTrue(this.runningStateByQueue.containsKey(logicalQueueName), "Queue with name '" + logicalQueueName + "' does not exist");
QueueAttributes queueAttributes = this.getRegisteredQueues().get(logicalQueueName);
startQueue(logicalQueueName, queueAttributes);
}
/**
* Checks if the spinning thread for the specified queue {@code logicalQueueName} is still running (polling for new
* messages) or not.
*
* @param logicalQueueName
* the name as defined on the listener method
* @return {@code true} if the spinning thread for the specified queue is running otherwise {@code false}.
*/
public boolean isRunning(String logicalQueueName) {
Future<?> future = this.scheduledFutureByQueue.get(logicalQueueName);
return future != null && !future.isCancelled() && !future.isDone();
}
protected void startQueue(String queueName, QueueAttributes queueAttributes) {
if (this.runningStateByQueue.containsKey(queueName) && this.runningStateByQueue.get(queueName)) {
return;
}
this.runningStateByQueue.put(queueName, true);
Future<?> future = getTaskExecutor().submit(new AsynchronousMessageListener(queueName, queueAttributes));
this.scheduledFutureByQueue.put(queueName, future);
}
private class AsynchronousMessageListener implements Runnable {
private final QueueAttributes queueAttributes;
private final String logicalQueueName;
private AsynchronousMessageListener(String logicalQueueName, QueueAttributes queueAttributes) {
this.logicalQueueName = logicalQueueName;
this.queueAttributes = queueAttributes;
}
@Override
public void run() {
while (isQueueRunning()) {
try {
ReceiveMessageResult receiveMessageResult = getAmazonSqs().receiveMessage(this.queueAttributes.getReceiveMessageRequest());
CountDownLatch messageBatchLatch = new CountDownLatch(receiveMessageResult.getMessages().size());
for (Message message : receiveMessageResult.getMessages()) {
if (isQueueRunning()) {
MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message, this.queueAttributes);
getTaskExecutor().execute(new SignalExecutingRunnable(messageBatchLatch, messageExecutor));
} else {
messageBatchLatch.countDown();
}
}
try {
messageBatchLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
} catch (Exception e) {
getLogger().warn("An Exception occurred while polling queue '{}'. The failing operation will be " +
"retried in {} milliseconds", this.logicalQueueName, getBackOffTime(), e);
try {
//noinspection BusyWait
Thread.sleep(getBackOffTime());
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
}
}
}
SimpleMessageListenerContainer.this.scheduledFutureByQueue.remove(this.logicalQueueName);
}
private boolean isQueueRunning() {
if (SimpleMessageListenerContainer.this.runningStateByQueue.containsKey(this.logicalQueueName)) {
return SimpleMessageListenerContainer.this.runningStateByQueue.get(this.logicalQueueName);
} else {
getLogger().warn("Stopped queue '" + this.logicalQueueName + "' because it was not listed as running queue.");
return false;
}
}
}
private class MessageExecutor implements Runnable {
private final Message message;
private final String logicalQueueName;
private final String queueUrl;
private final boolean hasRedrivePolicy;
private final SqsMessageDeletionPolicy deletionPolicy;
private MessageExecutor(String logicalQueueName, Message message, QueueAttributes queueAttributes) {
this.logicalQueueName = logicalQueueName;
this.message = message;
this.queueUrl = queueAttributes.getReceiveMessageRequest().getQueueUrl();
this.hasRedrivePolicy = queueAttributes.hasRedrivePolicy();
this.deletionPolicy = queueAttributes.getDeletionPolicy();
}
@Override
public void run() {
String receiptHandle = this.message.getReceiptHandle();
org.springframework.messaging.Message<String> queueMessage = getMessageForExecution();
try {
executeMessage(queueMessage);
applyDeletionPolicyOnSuccess(receiptHandle);
} catch (MessagingException messagingException) {
applyDeletionPolicyOnError(receiptHandle, messagingException);
}
}
private void applyDeletionPolicyOnSuccess(String receiptHandle) {
if (this.deletionPolicy == SqsMessageDeletionPolicy.ON_SUCCESS ||
this.deletionPolicy == SqsMessageDeletionPolicy.ALWAYS ||
this.deletionPolicy == SqsMessageDeletionPolicy.NO_REDRIVE) {
deleteMessage(receiptHandle);
}
}
private void applyDeletionPolicyOnError(String receiptHandle, MessagingException messagingException) {
if (this.deletionPolicy == SqsMessageDeletionPolicy.ALWAYS ||
(this.deletionPolicy == SqsMessageDeletionPolicy.NO_REDRIVE && !this.hasRedrivePolicy)) {
deleteMessage(receiptHandle);
} else if (this.deletionPolicy == SqsMessageDeletionPolicy.ON_SUCCESS) {
getLogger().error("Exception encountered while processing message.", messagingException);
}
}
private void deleteMessage(String receiptHandle) {
getAmazonSqs().deleteMessageAsync(new DeleteMessageRequest(this.queueUrl, receiptHandle));
}
private org.springframework.messaging.Message<String> getMessageForExecution() {
HashMap<String, Object> additionalHeaders = new HashMap<>();
additionalHeaders.put(QueueMessageHandler.LOGICAL_RESOURCE_ID, this.logicalQueueName);
if (this.deletionPolicy == SqsMessageDeletionPolicy.NEVER) {
String receiptHandle = this.message.getReceiptHandle();
QueueMessageAcknowledgment acknowledgment = new QueueMessageAcknowledgment(SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl, receiptHandle);
additionalHeaders.put(QueueMessageHandler.ACKNOWLEDGMENT, acknowledgment);
}
return createMessage(this.message, additionalHeaders);
}
}
private static class SignalExecutingRunnable implements Runnable {
private final CountDownLatch countDownLatch;
private final Runnable runnable;
private SignalExecutingRunnable(CountDownLatch endSignal, Runnable runnable) {
this.countDownLatch = endSignal;
this.runnable = runnable;
}
@Override
public void run() {
try {
this.runnable.run();
} finally {
this.countDownLatch.countDown();
}
}
}
}