package com.netflix.schlep.sqs.consumer;
import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.services.sqs.model.BatchResultErrorEntry;
import com.amazonaws.services.sqs.model.DeleteMessageBatchRequestEntry;
import com.amazonaws.services.sqs.model.DeleteMessageBatchResult;
import com.google.common.collect.Lists;
import com.netflix.schlep.Completion;
import com.netflix.schlep.consumer.IncomingMessage;
import com.netflix.schlep.consumer.PollingMessageConsumer;
import com.netflix.schlep.exception.ConsumerException;
import com.netflix.schlep.mapper.Base64Serializer;
import com.netflix.schlep.mapper.Serializer;
import com.netflix.schlep.sqs.AmazonSqsClient;
import com.netflix.schlep.sqs.SqsMessage;
import com.netflix.schlep.util.UnstoppableStopwatch;
class SqsMessageConsumer extends PollingMessageConsumer {
private static final Logger LOG = LoggerFactory.getLogger(SqsMessageConsumer.class);
public static final long DEFAULT_VISIBILITY_TIMEOUT = TimeUnit.MINUTES.toSeconds(5);
public static final Serializer DEFAULT_SERIALIZER = new Base64Serializer();
/**
* Builder
*
* @param <T>
*/
public static abstract class Builder<T extends Builder<T>> extends PollingMessageConsumer.Builder<T> {
private long visibilityTimeout = DEFAULT_VISIBILITY_TIMEOUT;
private Serializer serializer = DEFAULT_SERIALIZER;
private AmazonSqsClient.Builder clientBuilder = AmazonSqsClient.builder();
public T withCredentials(AWSCredentials credentials) {
this.clientBuilder.withCredentials(credentials);
return self();
}
public T withVisibilityTimeout(long timeout) {
this.visibilityTimeout = timeout;
return self();
}
public T withSerializer(Serializer serializer) {
this.serializer = serializer;
return self();
}
public T withConnectionTimeout(int connectTimeout) {
clientBuilder.withConnectionTimeout(connectTimeout);
return self();
}
public T withReadTimeout(int readTimeout) {
clientBuilder.withReadTimeout(readTimeout);
return self();
}
public T withMaxConnections(int maxConnections) {
clientBuilder.withMaxConnections(maxConnections);
return self();
}
public T withMaxRetries(int retries) {
clientBuilder.withMaxRetries(retries);
return self();
}
public T withQueueName(String queueName) {
clientBuilder.withQueueName(queueName);
return self();
}
public T withRegion(String region) {
clientBuilder.withRegion(region);
return self();
}
public SqsMessageConsumer build() throws Exception {
return new SqsMessageConsumer(this);
}
@Override
public String toString() {
return "Builder [visibilityTimeout="
+ visibilityTimeout + ", serializer=" + serializer
+ ", clientBuilder=" + clientBuilder + "]";
}
}
/**
* BuilderWrapper to link with subclass Builder
* @author elandau
*
*/
private static class BuilderWrapper extends Builder<BuilderWrapper> {
@Override
protected BuilderWrapper self() {
return this;
}
}
public static Builder<?> builder() {
return new BuilderWrapper();
}
private final Serializer serializer;
private final long visibilityTimeout;
private final AmazonSqsClient client;
private final AtomicLong ackInvalid = new AtomicLong(0);
private final AtomicLong ackFailure = new AtomicLong(0);
private final AtomicLong ackSuccess = new AtomicLong(0);
protected SqsMessageConsumer(Builder<?> init) throws Exception {
super(init);
this.visibilityTimeout = init.visibilityTimeout;
this.serializer = init.serializer;
this.client = init.clientBuilder.build();
}
/**
* Read a batch of messages and dispatch them serially, one by one
* @param maxMessageCount
* @param attributes
* @return
* @throws ConsumerException
*/
@Override
protected List<IncomingMessage> readBatch(int batchSize) throws ConsumerException {
try {
long timeout = visibilityTimeout;
// Execute the request
Collection<SqsMessage> result = client.receiveMessages(batchSize, timeout, null);
UnstoppableStopwatch sw = new UnstoppableStopwatch();
// Transform to internal response
List<IncomingMessage> messages = Lists.newArrayList();
for (final SqsMessage message : result) {
messages.add(new SqsIncomingMessage(message, sw, visibilityTimeout) {
@Override
public <T> T getContents(Class<T> clazz) {
ByteArrayInputStream bais = new ByteArrayInputStream(message.getMessage().getBody().getBytes());
try {
return (T)serializer.deserialize(bais, clazz);
} catch (Exception e) {
LOG.error("Failed to deserialize message", e);
throw new RuntimeException("Bad data format", e);
}
}
});
}
return messages;
} catch (Exception e) {
throw new ConsumerException("Error consuming messages " + getId(), e);
}
}
@Override
protected void sendAckBatch(List<Completion<IncomingMessage>> messages) {
List<Completion<IncomingMessage>> toAck = Lists.newArrayList(messages);
while (!toAck.isEmpty()) {
try {
// Construct a delete message request and assign each message an ID equivalent to it's position
// in the original list for fast lookup on the response
final List<DeleteMessageBatchRequestEntry> batchReqEntries = new ArrayList<DeleteMessageBatchRequestEntry>(messages.size());
int id = 0;
for (Completion<IncomingMessage> message : messages) {
SqsIncomingMessage sqsMessage = (SqsIncomingMessage)(message.getValue());
batchReqEntries.add(new DeleteMessageBatchRequestEntry(
Integer.toString(id),
sqsMessage.getMessage().getMessage().getReceiptHandle()));
++id;
}
// Send the request
DeleteMessageBatchResult result = client.deleteMessageBatch(batchReqEntries);
if (result.getSuccessful() != null) {
ackSuccess.addAndGet(result.getSuccessful().size());
}
// Handle failed sends
if (result.getFailed() != null && !result.getFailed().isEmpty()) {
toAck = Lists.newArrayListWithCapacity(result.getFailed().size());
for (BatchResultErrorEntry entry : result.getFailed()) {
// There cannot be resent and are probably the result of something like message exceeding
// the max size or certificate errors
if (entry.isSenderFault()) {
ackInvalid.incrementAndGet();
// TODO: messages.get(Integer.parseInt(entry.getId())).setException(new ProducerException(entry.getCode()));
}
// These messages can probably be resent and may be due to issues on the amazon side,
// such as service timeout
else {
ackFailure.incrementAndGet();
toAck.add(messages.get(Integer.parseInt(entry.getId())));
}
}
}
else {
return;
}
} catch (Exception e) {
LOG.error("Error acking messages " + getId(), e);
}
}
}
}