/* * Copyright 2010-2013 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. * A copy of the License is located at * * http://aws.amazon.com/apache2.0 * * or in the "license" file accompanying this file. This file is distributed * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either * express or implied. See the License for the specific language governing * permissions and limitations under the License. */ package com.amazonaws.services.sqs; import static com.amazonaws.util.StringUtils.UTF8; import com.amazonaws.AmazonClientException; import com.amazonaws.Request; import com.amazonaws.handlers.AbstractRequestHandler; import com.amazonaws.services.sqs.model.Message; import com.amazonaws.services.sqs.model.MessageAttributeValue; import com.amazonaws.services.sqs.model.ReceiveMessageRequest; import com.amazonaws.services.sqs.model.ReceiveMessageResult; import com.amazonaws.services.sqs.model.SendMessageBatchRequest; import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry; import com.amazonaws.services.sqs.model.SendMessageBatchResult; import com.amazonaws.services.sqs.model.SendMessageBatchResultEntry; import com.amazonaws.services.sqs.model.SendMessageRequest; import com.amazonaws.services.sqs.model.SendMessageResult; import com.amazonaws.util.BinaryUtils; import com.amazonaws.util.Md5Utils; import com.amazonaws.util.TimingInfo; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import java.io.UnsupportedEncodingException; import java.nio.ByteBuffer; import java.security.MessageDigest; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; /** * SQS operations on sending and receiving messages will return the MD5 digest * of the message body. This custom request handler will verify that the message * is correctly received by SQS, by comparing the returned MD5 with the * calculation according to the original request. */ public class MessageMD5ChecksumHandler extends AbstractRequestHandler { private static final int INTEGER_SIZE_IN_BYTES = 4; private static final byte STRING_TYPE_FIELD_INDEX = 1; private static final byte BINARY_TYPE_FIELD_INDEX = 2; private static final byte STRING_LIST_TYPE_FIELD_INDEX = 3; private static final byte BINARY_LIST_TYPE_FIELD_INDEX = 4; /* * Constant strings for composing error message. */ private static final String MD5_MISMATCH_ERROR_MESSAGE = "MD5 returned by SQS does not match the calculation on the original request. " + "(MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")"; private static final String MD5_MISMATCH_ERROR_MESSAGE_WITH_ID = "MD5 returned by SQS does not match the calculation on the original request. " + "(Message ID: %s, MD5 calculated by the %s: \"%s\", MD5 checksum returned: \"%s\")"; private static final String MESSAGE_BODY = "message body"; private static final String MESSAGE_ATTRIBUTES = "message attributes"; private static final Log log = LogFactory.getLog(MessageMD5ChecksumHandler.class); @Override public void afterResponse(Request<?> request, Object response, TimingInfo timingInfo) { if (request != null && response != null) { // SendMessage if (request.getOriginalRequest() instanceof SendMessageRequest && response instanceof SendMessageResult) { SendMessageRequest sendMessageRequest = (SendMessageRequest) request .getOriginalRequest(); SendMessageResult sendMessageResult = (SendMessageResult) response; sendMessageOperationMd5Check(sendMessageRequest, sendMessageResult); } // ReceiveMessage else if (request.getOriginalRequest() instanceof ReceiveMessageRequest && response instanceof ReceiveMessageResult) { ReceiveMessageResult receiveMessageResult = (ReceiveMessageResult) response; receiveMessageResultMd5Check(receiveMessageResult); } // SendMessageBatch else if (request.getOriginalRequest() instanceof SendMessageBatchRequest && response instanceof SendMessageBatchResult) { SendMessageBatchRequest sendMessageBatchRequest = (SendMessageBatchRequest) request .getOriginalRequest(); SendMessageBatchResult sendMessageBatchResult = (SendMessageBatchResult) response; sendMessageBatchOperationMd5Check(sendMessageBatchRequest, sendMessageBatchResult); } } } /** * Throw an exception if the MD5 checksums returned in the SendMessageResult * do not match the client-side calculation based on the original message in * the SendMessageRequest. */ private static void sendMessageOperationMd5Check(SendMessageRequest sendMessageRequest, SendMessageResult sendMessageResult) { String messageBodySent = sendMessageRequest.getMessageBody(); String bodyMd5Returned = sendMessageResult.getMD5OfMessageBody(); String clientSideBodyMd5 = calculateMessageBodyMd5(messageBodySent); if (!clientSideBodyMd5.equals(bodyMd5Returned)) { throw new AmazonClientException(String.format( MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY, clientSideBodyMd5, bodyMd5Returned)); } Map<String, MessageAttributeValue> messageAttrSent = sendMessageRequest .getMessageAttributes(); if (messageAttrSent != null && !messageAttrSent.isEmpty()) { String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttrSent); String attrMd5Returned = sendMessageResult.getMD5OfMessageAttributes(); if (!clientSideAttrMd5.equals(attrMd5Returned)) { throw new AmazonClientException(String.format( MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES, clientSideAttrMd5, attrMd5Returned)); } } } /** * Throw an exception if the MD5 checksums included in the * ReceiveMessageResult do not match the client-side calculation on the * received messages. */ private static void receiveMessageResultMd5Check(ReceiveMessageResult receiveMessageResult) { if (receiveMessageResult.getMessages() != null) { for (Message messageReceived : receiveMessageResult.getMessages()) { String messageBody = messageReceived.getBody(); String bodyMd5Returned = messageReceived.getMD5OfBody(); String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody); if (!clientSideBodyMd5.equals(bodyMd5Returned)) { throw new AmazonClientException(String.format( MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_BODY, clientSideBodyMd5, bodyMd5Returned)); } Map<String, MessageAttributeValue> messageAttr = messageReceived .getMessageAttributes(); if (messageAttr != null && !messageAttr.isEmpty()) { String attrMd5Returned = messageReceived.getMD5OfMessageAttributes(); String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr); if (!clientSideAttrMd5.equals(attrMd5Returned)) { throw new AmazonClientException(String.format( MD5_MISMATCH_ERROR_MESSAGE, MESSAGE_ATTRIBUTES, clientSideAttrMd5, attrMd5Returned)); } } } } } /** * Throw an exception if the MD5 checksums returned in the * SendMessageBatchResult do not match the client-side calculation based on * the original messages in the SendMessageBatchRequest. */ private static void sendMessageBatchOperationMd5Check( SendMessageBatchRequest sendMessageBatchRequest, SendMessageBatchResult sendMessageBatchResult) { Map<String, SendMessageBatchRequestEntry> idToRequestEntryMap = new HashMap<String, SendMessageBatchRequestEntry>(); if (sendMessageBatchRequest.getEntries() != null) { for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.getEntries()) { idToRequestEntryMap.put(entry.getId(), entry); } } if (sendMessageBatchResult.getSuccessful() != null) { for (SendMessageBatchResultEntry entry : sendMessageBatchResult.getSuccessful()) { String messageBody = idToRequestEntryMap.get(entry.getId()).getMessageBody(); String bodyMd5Returned = entry.getMD5OfMessageBody(); String clientSideBodyMd5 = calculateMessageBodyMd5(messageBody); if (!clientSideBodyMd5.equals(bodyMd5Returned)) { throw new AmazonClientException(String.format( MD5_MISMATCH_ERROR_MESSAGE_WITH_ID, MESSAGE_BODY, entry.getId(), clientSideBodyMd5, bodyMd5Returned)); } Map<String, MessageAttributeValue> messageAttr = idToRequestEntryMap.get( entry.getId()).getMessageAttributes(); if (messageAttr != null && !messageAttr.isEmpty()) { String attrMd5Returned = entry.getMD5OfMessageAttributes(); String clientSideAttrMd5 = calculateMessageAttributesMd5(messageAttr); if (!clientSideAttrMd5.equals(attrMd5Returned)) { throw new AmazonClientException(String.format( MD5_MISMATCH_ERROR_MESSAGE_WITH_ID, MESSAGE_ATTRIBUTES, entry.getId(), clientSideAttrMd5, attrMd5Returned)); } } } } } /** * Returns the hex-encoded MD5 hash String of the given message body. */ private static String calculateMessageBodyMd5(String messageBody) { if (log.isDebugEnabled()) { log.debug("Message body: " + messageBody); } byte[] expectedMd5; try { expectedMd5 = Md5Utils.computeMD5Hash(messageBody.getBytes(UTF8)); } catch (Exception e) { throw new AmazonClientException( "Unable to calculate the MD5 hash of the message body. " + e.getMessage(), e); } String expectedMd5Hex = BinaryUtils.toHex(expectedMd5); if (log.isDebugEnabled()) { log.debug("Expected MD5 of message body: " + expectedMd5Hex); } return expectedMd5Hex; } /** * Returns the hex-encoded MD5 hash String of the given message attributes. */ private static String calculateMessageAttributesMd5( final Map<String, MessageAttributeValue> messageAttributes) { if (log.isDebugEnabled()) { log.debug("Message attribtues: " + messageAttributes); } List<String> sortedAttributeNames = new ArrayList<String>(messageAttributes.keySet()); Collections.sort(sortedAttributeNames); MessageDigest md5Digest = null; try { md5Digest = MessageDigest.getInstance("MD5"); for (String attrName : sortedAttributeNames) { MessageAttributeValue attrValue = messageAttributes.get(attrName); // Encoded Name updateLengthAndBytes(md5Digest, attrName); // Encoded Type updateLengthAndBytes(md5Digest, attrValue.getDataType()); // Encoded Value if (attrValue.getStringValue() != null) { md5Digest.update(STRING_TYPE_FIELD_INDEX); updateLengthAndBytes(md5Digest, attrValue.getStringValue()); } else if (attrValue.getBinaryValue() != null) { md5Digest.update(BINARY_TYPE_FIELD_INDEX); updateLengthAndBytes(md5Digest, attrValue.getBinaryValue()); } else if (attrValue.getStringListValues() != null) { md5Digest.update(STRING_LIST_TYPE_FIELD_INDEX); for (String strListMember : attrValue.getStringListValues()) { updateLengthAndBytes(md5Digest, strListMember); } } else if (attrValue.getBinaryListValues() != null) { md5Digest.update(BINARY_LIST_TYPE_FIELD_INDEX); for (ByteBuffer byteListMember : attrValue.getBinaryListValues()) { updateLengthAndBytes(md5Digest, byteListMember); } } } } catch (Exception e) { throw new AmazonClientException( "Unable to calculate the MD5 hash of the message attributes. " + e.getMessage(), e); } String expectedMd5Hex = BinaryUtils.toHex(md5Digest.digest()); if (log.isDebugEnabled()) { log.debug("Expected MD5 of message attributes: " + expectedMd5Hex); } return expectedMd5Hex; } /** * Update the digest using a sequence of bytes that consists of the length * (in 4 bytes) of the input String and the actual utf8-encoded byte values. */ private static void updateLengthAndBytes(MessageDigest digest, String str) throws UnsupportedEncodingException { byte[] utf8Encoded = str.getBytes(UTF8); ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt( utf8Encoded.length); digest.update(lengthBytes.array()); digest.update(utf8Encoded); } /** * Update the digest using a sequence of bytes that consists of the length * (in 4 bytes) of the input ByteBuffer and all the bytes it contains. */ private static void updateLengthAndBytes(MessageDigest digest, ByteBuffer binaryValue) { // Rewind the ByteBuffer, in case that get/put operations were applied // to // the unmarshalled BB before it's passed to this handler. binaryValue.rewind(); int size = binaryValue.remaining(); ByteBuffer lengthBytes = ByteBuffer.allocate(INTEGER_SIZE_IN_BYTES).putInt(size); digest.update(lengthBytes.array()); digest.update(binaryValue); } }