/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.connectors.rabbitmq;
import com.rabbitmq.client.AMQP;
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.client.Envelope;
import com.rabbitmq.client.QueueingConsumer;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.OperatorStateStore;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.connectors.rabbitmq.common.RMQConnectionConfig;
import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.apache.flink.streaming.util.serialization.DeserializationSchema;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.powermock.modules.junit4.PowerMockRunner;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
/**
 * Tests for the RMQSource. The source supports three operation modes.
 * 1) Exactly-once (when checkpointed) with RabbitMQ transactions and the deduplication mechanism in
 * {@link org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourceBase}.
 * 2) At-least-once (when checkpointed) with RabbitMQ transactions but not deduplication.
 * 3) No strong delivery guarantees (without checkpointing) with RabbitMQ auto-commit mode.
 *
 * <p>These tests assume that message ids increase monotonically. That does not have to be the
 * case in general; the correlation id is what uniquely identifies messages.
 */
@RunWith(PowerMockRunner.class)
public class RMQSourceTest {

	/** Source under test; created, initialized and opened anew before every test. */
	private RMQSource<String> source;

	/** Configuration passed to {@link RMQSource#open(Configuration)} in {@link #beforeTest()}. */
	private Configuration config = new Configuration();

	/** Runs {@link #source} with a {@link DummySourceContext}; only started by tests that need it. */
	private Thread sourceThread;

	/**
	 * Delivery tag of the most recent mocked delivery. It is incremented by the mocked envelope
	 * (on the source thread) and read/reset by the test thread, hence {@code volatile}.
	 */
	private volatile long messageId;

	/** When {@code false}, the mocked message properties report a {@code null} correlation id. */
	private boolean generateCorrelationIds;

	/** Exception that escaped {@link RMQSource#run} on {@link #sourceThread}, if any. */
	private volatile Exception exception;

	/**
	 * Creates an {@link RMQTestSource} backed by mocked operator state and RabbitMQ objects, opens
	 * it, and prepares (but does not start) the thread that will run it.
	 */
	@Before
	public void beforeTest() throws Exception {
		OperatorStateStore mockStore = Mockito.mock(OperatorStateStore.class);
		FunctionInitializationContext mockContext = Mockito.mock(FunctionInitializationContext.class);
		Mockito.when(mockContext.getOperatorStateStore()).thenReturn(mockStore);
		Mockito.when(mockStore.getSerializableListState(any(String.class))).thenReturn(null);

		source = new RMQTestSource();
		source.initializeState(mockContext);
		source.open(config);

		messageId = 0;
		generateCorrelationIds = true;

		sourceThread = new Thread(new Runnable() {
			@Override
			public void run() {
				try {
					source.run(new DummySourceContext());
				} catch (Exception e) {
					// handed to the test thread through the volatile field
					exception = e;
				}
			}
		});
	}

	/** Cancels the source and waits for its thread (join returns at once if it never started). */
	@After
	public void afterTest() throws Exception {
		source.cancel();
		sourceThread.join();
	}

	/**
	 * Verifies that {@link RMQSource#open(Configuration)} fails with a descriptive message when the
	 * RabbitMQ connection does not hand out a channel.
	 */
	@Test
	public void throwExceptionIfConnectionFactoryReturnNull() throws Exception {
		RMQConnectionConfig connectionConfig = Mockito.mock(RMQConnectionConfig.class);
		ConnectionFactory connectionFactory = Mockito.mock(ConnectionFactory.class);
		Connection connection = Mockito.mock(Connection.class);
		Mockito.when(connectionConfig.getConnectionFactory()).thenReturn(connectionFactory);
		Mockito.when(connectionFactory.newConnection()).thenReturn(connection);
		Mockito.when(connection.createChannel()).thenReturn(null);

		RMQSource<String> rmqSource = new RMQSource<>(
			connectionConfig, "queueDummy", true, new StringDeserializationScheme());
		try {
			rmqSource.open(new Configuration());
			// Without this the test would silently pass if open() stopped throwing.
			// fail() raises an AssertionError, which the RuntimeException handler below does not catch.
			fail("open() should have failed because no RabbitMQ channel was available");
		} catch (RuntimeException ex) {
			assertEquals("None of RabbitMQ channels are available", ex.getMessage());
		}
	}

	/**
	 * Takes a series of snapshots while the source is running. It checks that each restored
	 * snapshot holds exactly the ids collected since the previous snapshot. It also checks that
	 * completing a checkpoint acknowledges those ids and commits the RabbitMQ transaction.
	 */
	@Test
	public void testCheckpointing() throws Exception {
		source.autoAck = false;

		StreamSource<String, RMQSource<String>> src = new StreamSource<>(source);
		AbstractStreamOperatorTestHarness<String> testHarness =
			new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
		testHarness.open();

		sourceThread.start();

		Thread.sleep(5);

		final Random random = new Random(System.currentTimeMillis());
		int numSnapshots = 50;
		long previousSnapshotId;
		long lastSnapshotId = 0;

		long totalNumberOfAcks = 0;

		for (int i = 0; i < numSnapshots; i++) {
			long snapshotId = random.nextLong();
			OperatorStateHandles data;

			// Hold the checkpoint lock handed out by DummySourceContext while snapshotting, so that
			// (assuming the source emits under that lock) messageId is read consistently with the snapshot.
			synchronized (DummySourceContext.lock) {
				data = testHarness.snapshot(snapshotId, System.currentTimeMillis());
				previousSnapshotId = lastSnapshotId;
				lastSnapshotId = messageId;
			}
			// let some time pass
			Thread.sleep(5);

			// check if the correct number of messages have been snapshotted
			final long numIds = lastSnapshotId - previousSnapshotId;

			// restore the snapshot into a fresh source and inspect what it restored
			RMQTestSource sourceCopy = new RMQTestSource();
			StreamSource<String, RMQTestSource> srcCopy = new StreamSource<>(sourceCopy);
			AbstractStreamOperatorTestHarness<String> testHarnessCopy =
				new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);

			testHarnessCopy.setup();
			testHarnessCopy.initializeState(data);
			testHarnessCopy.open();

			ArrayDeque<Tuple2<Long, List<String>>> deque = sourceCopy.getRestoredState();
			List<String> messageIds = deque.getLast().f1;

			assertEquals(numIds, messageIds.size());
			if (messageIds.size() > 0) {
				assertEquals(lastSnapshotId, (long) Long.valueOf(messageIds.get(messageIds.size() - 1)));
			}

			// check if the messages are being acknowledged and the transaction committed
			synchronized (DummySourceContext.lock) {
				source.notifyCheckpointComplete(snapshotId);
			}
			totalNumberOfAcks += numIds;
		}

		Mockito.verify(source.channel, Mockito.times((int) totalNumberOfAcks)).basicAck(Mockito.anyLong(), Mockito.eq(false));
		Mockito.verify(source.channel, Mockito.times(numSnapshots)).txCommit();
	}

	/**
	 * Checks whether recurring ids are processed again (they shouldn't be).
	 */
	@Test
	public void testDuplicateId() throws Exception {
		source.autoAck = false;
		sourceThread.start();

		while (messageId < 10) {
			// wait until messages have been processed
			Thread.sleep(5);
		}

		long oldMessageId;
		synchronized (DummySourceContext.lock) {
			oldMessageId = messageId;
			// rewind the id generator so the next deliveries repeat already-seen ids
			messageId = 0;
		}

		while (messageId < 10) {
			// process again
			Thread.sleep(5);
		}

		synchronized (DummySourceContext.lock) {
			// replayed ids must not have been collected a second time
			assertEquals(Math.max(messageId, oldMessageId), DummySourceContext.numElementsCollected);
		}
	}

	/**
	 * The source should not acknowledge ids in auto-commit mode or check for previously acknowledged ids.
	 */
	@Test
	public void testCheckpointingDisabled() throws Exception {
		source.autoAck = true;
		sourceThread.start();
		while (DummySourceContext.numElementsCollected < 50) {
			// wait until messages have been processed
			Thread.sleep(5);
		}
		// The actual assertion lives in RMQTestSource#addId (autoAck must be false there).
		// NOTE(review): that AssertionError is an Error, so the source thread's catch (Exception)
		// does not record it. A violation would end the source thread, and this loop would then
		// wait indefinitely instead of failing cleanly.
	}

	/**
	 * Tests error reporting in case of invalid correlation ids.
	 */
	@Test
	public void testCorrelationIdNotSet() throws InterruptedException {
		generateCorrelationIds = false;
		source.autoAck = false;
		sourceThread.start();

		sourceThread.join();

		assertNotNull(exception);
		assertTrue(exception instanceof NullPointerException);
	}

	/**
	 * Tests whether constructor params are passed correctly.
	 */
	@Test
	public void testConstructorParams() throws Exception {
		// verify construction params
		RMQConnectionConfig.Builder builder = new RMQConnectionConfig.Builder();
		builder.setHost("hostTest").setPort(999).setUserName("userTest").setPassword("passTest").setVirtualHost("/");
		ConstructorTestClass testObj = new ConstructorTestClass(
			builder.build(), "queueTest", false, new StringDeserializationScheme());

		try {
			testObj.open(new Configuration());
		} catch (Exception e) {
			// connection fails but check if args have been passed correctly
		}

		assertEquals("hostTest", testObj.getFactory().getHost());
		assertEquals(999, testObj.getFactory().getPort());
		assertEquals("userTest", testObj.getFactory().getUsername());
		assertEquals("passTest", testObj.getFactory().getPassword());
	}

	/**
	 * Source whose connection factory is a spy built from fixed test parameters. The spy's
	 * {@code newConnection()} is stubbed to throw, so {@code open()} never reaches a real broker.
	 */
	private static class ConstructorTestClass extends RMQSource<String> {

		private ConnectionFactory factory;

		public ConstructorTestClass(RMQConnectionConfig rmqConnectionConfig,
									String queueName,
									boolean usesCorrelationId,
									DeserializationSchema<String> deserializationSchema) throws Exception {
			super(rmqConnectionConfig, queueName, usesCorrelationId, deserializationSchema);
			RMQConnectionConfig.Builder builder = new RMQConnectionConfig.Builder();
			builder.setHost("hostTest").setPort(999).setUserName("userTest").setPassword("passTest").setVirtualHost("/");
			factory = Mockito.spy(builder.build().getConnectionFactory());
			try {
				Mockito.doThrow(new RuntimeException()).when(factory).newConnection();
			} catch (IOException e) {
				fail("Failed to stub connection method");
			}
		}

		@Override
		protected ConnectionFactory setupConnectionFactory() {
			return factory;
		}

		/** Returns the spied factory so the test can inspect its connection settings. */
		public ConnectionFactory getFactory() {
			return factory;
		}
	}

	/** Decodes message bytes with {@link ConfigConstants#DEFAULT_CHARSET}, sleeping briefly per message. */
	private static class StringDeserializationScheme implements DeserializationSchema<String> {

		@Override
		public String deserialize(byte[] message) throws IOException {
			try {
				// wait a bit to not cause too much cpu load
				Thread.sleep(1);
			} catch (InterruptedException e) {
				// restore the interrupt flag instead of swallowing it
				Thread.currentThread().interrupt();
			}
			return new String(message, ConfigConstants.DEFAULT_CHARSET);
		}

		@Override
		public boolean isEndOfStream(String nextElement) {
			return false;
		}

		@Override
		public TypeInformation<String> getProducedType() {
			return TypeExtractor.getForClass(String.class);
		}
	}

	/**
	 * Source backed by mocked RabbitMQ objects. Every delivery has body "test" and a delivery tag
	 * taken from the enclosing test's {@code messageId} counter. The correlation id equals the
	 * current {@code messageId}, or is {@code null} when correlation ids are disabled.
	 * Non-static because it reads that outer test state.
	 */
	private class RMQTestSource extends RMQSource<String> {

		private ArrayDeque<Tuple2<Long, List<String>>> restoredState;

		public RMQTestSource() {
			super(new RMQConnectionConfig.Builder().setHost("hostTest")
					.setPort(999).setUserName("userTest").setPassword("passTest").setVirtualHost("/").build()
				, "queueDummy", true, new StringDeserializationScheme());
		}

		@Override
		public void initializeState(FunctionInitializationContext context) throws Exception {
			super.initializeState(context);
			// keep a handle on the pending checkpoints so tests can inspect what was restored
			this.restoredState = this.pendingCheckpoints;
		}

		/** Returns the pending checkpoints captured at the end of {@link #initializeState}. */
		public ArrayDeque<Tuple2<Long, List<String>>> getRestoredState() {
			return this.restoredState;
		}

		@Override
		public void open(Configuration config) throws Exception {
			super.open(config);

			// replace the consumer created by super.open with a mock that always returns the same delivery
			consumer = Mockito.mock(QueueingConsumer.class);

			// Mock for delivery
			final QueueingConsumer.Delivery deliveryMock = Mockito.mock(QueueingConsumer.Delivery.class);
			Mockito.when(deliveryMock.getBody()).thenReturn("test".getBytes(ConfigConstants.DEFAULT_CHARSET));

			try {
				Mockito.when(consumer.nextDelivery()).thenReturn(deliveryMock);
			} catch (InterruptedException e) {
				fail("Couldn't setup up deliveryMock");
			}

			// Mock for envelope: every delivery-tag lookup advances the shared messageId counter
			Envelope envelope = Mockito.mock(Envelope.class);
			Mockito.when(deliveryMock.getEnvelope()).thenReturn(envelope);

			Mockito.when(envelope.getDeliveryTag()).thenAnswer(new Answer<Long>() {
				@Override
				public Long answer(InvocationOnMock invocation) throws Throwable {
					return ++messageId;
				}
			});

			// Mock for properties: correlation id mirrors the current messageId (or null when disabled)
			AMQP.BasicProperties props = Mockito.mock(AMQP.BasicProperties.class);
			Mockito.when(deliveryMock.getProperties()).thenReturn(props);

			Mockito.when(props.getCorrelationId()).thenAnswer(new Answer<String>() {
				@Override
				public String answer(InvocationOnMock invocation) throws Throwable {
					return generateCorrelationIds ? "" + messageId : null;
				}
			});
		}

		@Override
		protected ConnectionFactory setupConnectionFactory() {
			ConnectionFactory connectionFactory = Mockito.mock(ConnectionFactory.class);
			Connection connection = Mockito.mock(Connection.class);
			try {
				Mockito.when(connectionFactory.newConnection()).thenReturn(connection);
				Mockito.when(connection.createChannel()).thenReturn(Mockito.mock(Channel.class));
			} catch (IOException e) {
				fail("Test environment couldn't be created.");
			}
			return connectionFactory;
		}

		@Override
		public RuntimeContext getRuntimeContext() {
			return Mockito.mock(StreamingRuntimeContext.class);
		}

		@Override
		protected boolean addId(String uid) {
			// ids must only be tracked when acknowledgements are manual (autoAck disabled)
			assertEquals(false, autoAck);
			return super.addId(uid);
		}
	}

	/**
	 * Source context that only counts collected elements. It hands out a single static lock as
	 * its checkpoint lock.
	 */
	private static class DummySourceContext implements SourceFunction.SourceContext<String> {

		private static final Object lock = new Object();

		/**
		 * Number of elements passed to {@link #collect}. It is incremented from the source thread
		 * and polled by the test thread, so it is {@code volatile} to guarantee the poller sees progress.
		 */
		private static volatile long numElementsCollected;

		public DummySourceContext() {
			numElementsCollected = 0;
		}

		@Override
		public void collect(String element) {
			numElementsCollected++;
		}

		@Override
		public void collectWithTimestamp(String element, long timestamp) {
		}

		@Override
		public void emitWatermark(Watermark mark) {
			throw new UnsupportedOperationException();
		}

		@Override
		public void markAsTemporarilyIdle() {
			throw new UnsupportedOperationException();
		}

		@Override
		public Object getCheckpointLock() {
			return lock;
		}

		@Override
		public void close() {
		}
	}
}