package tw.com.providers;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.commons.cli.MissingArgumentException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.amazonaws.auth.policy.Action;
import com.amazonaws.auth.policy.Condition;
import com.amazonaws.auth.policy.Policy;
import com.amazonaws.auth.policy.Principal;
import com.amazonaws.auth.policy.Resource;
import com.amazonaws.auth.policy.Statement;
import com.amazonaws.auth.policy.actions.SQSActions;
import com.amazonaws.auth.policy.conditions.ConditionFactory;
import com.amazonaws.services.sqs.AmazonSQSClient;
import com.amazonaws.services.sqs.model.GetQueueAttributesRequest;
import com.amazonaws.services.sqs.model.GetQueueAttributesResult;
import com.amazonaws.services.sqs.model.SetQueueAttributesRequest;
public class QueuePolicyManager {
private static final Logger logger = LoggerFactory.getLogger(QueuePolicyManager.class);
private static final String QUEUE_POLICY_KEY = "Policy";
public static final String QUEUE_ARN_KEY = "QueueArn";
protected AmazonSQSClient sqsClient;
private Collection<String> attributeNames = new LinkedList<String>();
public QueuePolicyManager(AmazonSQSClient sqsClient) {
this.sqsClient = sqsClient;
attributeNames.add(QUEUE_ARN_KEY);
attributeNames.add(QUEUE_POLICY_KEY);
}
public void checkOrCreateQueuePermissions(
Map<String, String> queueAttributes, String topicSnsArn, String queueArn, String queueURL) {
Policy policy = extractPolicy(queueAttributes);
if (policy!=null) {
logger.info("Policy found for queue, check if required conditions set");
for (Statement statement : policy.getStatements()) {
if (allowQueuePublish(statement)) {
logger.info("Statement allows sending, checking for ARN condition. Statement ID is " + statement.getId());
for (Condition condition : statement.getConditions()) {
if (condition.getConditionKey().equals("aws:SourceArn") &&
condition.getValues().contains(topicSnsArn)) {
logger.info("Found a matching condition for sns arn " + topicSnsArn);
return;
}
}
}
}
}
logger.info("Policy allowing SNS to publish to queue not found");
setQueuePolicy(topicSnsArn, queueArn, queueURL);
}
private void setQueuePolicy(String topicSnsArn, String queueArn, String queueURL) {
logger.info("Set up policy for queue to allow SNS to publish to it");
Policy sqsPolicy = new Policy()
.withStatements(new Statement(Statement.Effect.Allow)
.withPrincipals(Principal.AllUsers)
.withResources(new Resource(queueArn))
.withConditions(ConditionFactory.newSourceArnCondition(topicSnsArn))
.withActions(SQSActions.SendMessage));
Map<String, String> attributes = new HashMap<String,String>();
attributes.put("Policy", sqsPolicy.toJson());
SetQueueAttributesRequest setQueueAttributesRequest = new SetQueueAttributesRequest();
setQueueAttributesRequest.setQueueUrl(queueURL);
setQueueAttributesRequest.setAttributes(attributes);
sqsClient.setQueueAttributes(setQueueAttributesRequest);
}
private boolean allowQueuePublish(Statement statement) {
if (statement.getEffect().equals(Statement.Effect.Allow)) {
List<Action> actions = statement.getActions();
for(Action action : actions) { // .equals not properly defined on actions
if (action.getActionName().equals("sqs:"+SQSActions.SendMessage.toString())) {
return true;
}
}
}
return false;
}
public Map<String, String> getQueueAttributes(String url) throws MissingArgumentException {
// find the queue arn, we need this to create the SNS subscription
GetQueueAttributesRequest getQueueAttributesRequest = new GetQueueAttributesRequest(url);
getQueueAttributesRequest.setAttributeNames(attributeNames);
GetQueueAttributesResult attribResult = sqsClient.getQueueAttributes(getQueueAttributesRequest);
Map<String, String> attribMap = attribResult.getAttributes();
if (!attribMap.containsKey(QUEUE_ARN_KEY)) {
String msg = "Missing arn attirbute, tried attribute with name: " + QUEUE_ARN_KEY;
logger.error(msg);
throw new MissingArgumentException(msg);
}
return attribMap;
}
private Policy extractPolicy(Map<String, String> queueAttributes) {
String policyJson = queueAttributes.get(QUEUE_POLICY_KEY);
if (policyJson==null) {
return null;
}
logger.debug("Current queue policy: " + policyJson);
Policy policy = Policy.fromJson(policyJson);
return policy;
}
}