package com.netflix.suro.input.remotefile; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.util.StringInputStream; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableMap; import com.netflix.suro.input.RecordParser; import com.netflix.suro.input.SuroInput; import com.netflix.suro.jackson.DefaultObjectMapper; import com.netflix.suro.message.MessageContainer; import com.netflix.suro.routing.MessageRouter; import com.netflix.suro.sink.notice.Notice; import com.netflix.util.Pair; import org.jets3t.service.impl.rest.httpclient.RestS3Service; import org.jets3t.service.model.S3Object; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.io.File; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentSkipListSet; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.*; public class TestS3Consumer { @Rule public TemporaryFolder tempDir = new TemporaryFolder(); private ObjectMapper jsonMapper = new DefaultObjectMapper(); private final int testFileCount = 6; @Test public void test() throws Exception { final String downloadPath = tempDir.newFolder().getAbsolutePath(); final CountDownLatch latch = new CountDownLatch(1); final ConcurrentSkipListSet<String> removedKeys = new ConcurrentSkipListSet<String>(); final AtomicInteger count = new AtomicInteger(0); final AtomicInteger peekedMsgCount = new AtomicInteger(0); final AtomicInteger invalidMsgCount = new AtomicInteger(0); Notice<String> mockedNotice = new Notice<String>() { @Override public void init() { } @Override public boolean send(String message) { return false; } @Override public String recv() { return null; } @Override public Pair<String, String> peek() { if (peekedMsgCount.get() == 1) { // return invalid msg invalidMsgCount.incrementAndGet(); return new Pair<String, String>("receiptHandle" + peekedMsgCount.getAndIncrement(), "invalid_msg"); } if (peekedMsgCount.get() == 3) { // return invalid msg invalidMsgCount.incrementAndGet(); return new Pair<String, String>("receiptHandle" + peekedMsgCount.getAndIncrement(), "{\n" + " \"Message\": {\n" + " \"s3Bucket\": \"bucket\",\n" + " \"s3ObjectKey\": \"key\"\n" + " }\n" + "}"); } if (peekedMsgCount.get() == 5) { // return invalid msg invalidMsgCount.incrementAndGet(); return new Pair<String, String>("receiptHandle" + peekedMsgCount.getAndIncrement(), "{\n" + " \"Message\": {\n" + " \"Bucket\": \"bucket\",\n" + " \"ObjectKey\": [\"key\"]\n" + " }\n" + "}"); } try { List<String> dummyKeys = new ArrayList<String>(); dummyKeys.add("prefix/key" + (count.getAndIncrement())); dummyKeys.add("prefix/key" + (count.getAndIncrement())); return new Pair<String, String>( "receiptHandle" + peekedMsgCount.getAndIncrement(), jsonMapper.writeValueAsString( new ImmutableMap.Builder<String, Object>() .put("Message", new ImmutableMap.Builder<String, Object>() .put("s3Bucket", "bucket") .put("s3ObjectKey", dummyKeys) .build()) .build())); } catch (JsonProcessingException e) { throw new RuntimeException(e); } finally { if (count.get() == testFileCount) { latch.countDown(); } } } @Override public void remove(String key) { removedKeys.add(key); } @Override public String getStat() { return null; } }; AWSCredentialsProvider awsCredentials = mock(AWSCredentialsProvider.class); AWSCredentials credentials = mock(AWSCredentials.class); doReturn("accessKey").when(credentials).getAWSAccessKeyId(); doReturn("secretKey").when(credentials).getAWSSecretKey(); doReturn(credentials).when(awsCredentials).getCredentials(); MessageRouter router = mock(MessageRouter.class); int numOfLines = 3; final StringBuilder sb = new StringBuilder(); for (int i = 0; i < numOfLines; ++i) { sb.append("line" + i).append('\n'); } RestS3Service s3 = mock(RestS3Service.class); doAnswer(new Answer<S3Object>() { @Override public S3Object answer(InvocationOnMock invocation) throws Throwable { S3Object obj = mock(S3Object.class); doReturn(new StringInputStream(sb.toString())).when(obj).getDataInputStream(); return obj; } }).when(s3).getObject(anyString(), anyString()); RecordParser recordParser = mock(RecordParser.class); List<MessageContainer> messages = new ArrayList<MessageContainer>(); int numOfMessages = 5; for (int i = 0; i < numOfMessages; ++i) { messages.add(mock(MessageContainer.class)); } doReturn(messages).when(recordParser).parse(anyString()); S3Consumer consumer = new S3Consumer( "id", "s3Endpoint", mockedNotice, 1000, 3, downloadPath, recordParser, awsCredentials, router, jsonMapper, s3); consumer.start(); latch.await(); consumer.shutdown(); verify(router, times(numOfMessages * numOfLines * count.get())).process(any(SuroInput.class), any(MessageContainer.class)); assertEquals(removedKeys.size(), peekedMsgCount.get() - invalidMsgCount.get()); // no files under downloadPath assertEquals(new File(downloadPath).list().length, 0); } }