/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.connectors.kinesis.testutils;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicReference;
/**
 * A {@link KinesisDataFetcher} for tests.
 *
 * <p>It is constructed with a mocked {@link SourceFunction.SourceContext} and a mocked
 * {@link RuntimeContext} that reports the given subtask count/index. Emitted records are not
 * forwarded anywhere; instead they are counted (see {@link #getNumOfElementsCollected()}) and the
 * shard state is updated. Tests can block on {@link #waitUntilRun()} until {@link #runFetcher()}
 * has been entered.
 */
public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {

	/**
	 * Checkpoint lock handed to the super constructor. It has to be static because instance
	 * fields cannot be referenced inside the {@code super(...)} call; as a consequence, all
	 * instances of this class share the same lock.
	 */
	private static final Object fakeCheckpointLock = new Object();

	/** Number of records passed to {@link #emitRecordAndUpdateState}; guarded by {@link #fakeCheckpointLock}. */
	private long numElementsCollected;

	/** Triggered as soon as {@link #runFetcher()} is entered. */
	private final OneShotLatch runWaiter;

	/**
	 * Creates a fetcher backed by mocked Flink contexts.
	 *
	 * @param fakeStreams streams to consume
	 * @param fakeConfiguration consumer configuration passed to the fetcher
	 * @param fakeTotalCountOfSubtasks value reported by the mocked {@code getNumberOfParallelSubtasks()}
	 * @param fakeIndexOfThisSubtask value reported by the mocked {@code getIndexOfThisSubtask()}
	 * @param thrownErrorUnderTest reference passed to the fetcher as its error holder
	 * @param subscribedShardsStateUnderTest shard state list passed to the fetcher
	 * @param subscribedStreamsToLastDiscoveredShardIdsStateUnderTest stream-to-last-shard-id map passed to the fetcher
	 * @param fakeKinesis Kinesis proxy the fetcher talks to
	 */
	public TestableKinesisDataFetcher(List<String> fakeStreams,
									Properties fakeConfiguration,
									int fakeTotalCountOfSubtasks,
									int fakeIndexOfThisSubtask,
									AtomicReference<Throwable> thrownErrorUnderTest,
									LinkedList<KinesisStreamShardState> subscribedShardsStateUnderTest,
									HashMap<String, String> subscribedStreamsToLastDiscoveredShardIdsStateUnderTest,
									KinesisProxyInterface fakeKinesis) {
		super(fakeStreams,
			getMockedSourceContext(),
			fakeCheckpointLock,
			getMockedRuntimeContext(fakeTotalCountOfSubtasks, fakeIndexOfThisSubtask),
			fakeConfiguration,
			new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
			thrownErrorUnderTest,
			subscribedShardsStateUnderTest,
			subscribedStreamsToLastDiscoveredShardIdsStateUnderTest,
			fakeKinesis);

		this.runWaiter = new OneShotLatch();
	}

	/**
	 * Returns how many records have been emitted so far.
	 *
	 * <p>The counter is written by fetcher threads under {@link #fakeCheckpointLock}; reading it
	 * under the same lock guarantees the caller (typically the test thread) sees the latest value.
	 *
	 * @return number of records passed to {@link #emitRecordAndUpdateState}
	 */
	public long getNumOfElementsCollected() {
		synchronized (fakeCheckpointLock) {
			return numElementsCollected;
		}
	}

	/** Returns a fresh wrapper around a {@link SimpleStringSchema} instead of cloning the configured schema. */
	@Override
	protected KinesisDeserializationSchema<String> getClonedDeserializationSchema() {
		return new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema());
	}

	/**
	 * Counts the record (without emitting it) and updates the shard state, both under the
	 * checkpoint lock.
	 */
	@Override
	protected void emitRecordAndUpdateState(String record, long recordTimestamp, int shardStateIndex, SequenceNumber lastSequenceNumber) {
		synchronized (fakeCheckpointLock) {
			this.numElementsCollected++;
			updateState(shardStateIndex, lastSequenceNumber);
		}
	}

	/** Signals {@link #waitUntilRun()} and then delegates to the real fetcher loop. */
	@Override
	public void runFetcher() throws Exception {
		runWaiter.trigger();
		super.runFetcher();
	}

	/**
	 * Blocks until {@link #runFetcher()} has been entered.
	 *
	 * @throws Exception if waiting on the latch fails (e.g. the thread is interrupted)
	 */
	public void waitUntilRun() throws Exception {
		runWaiter.await();
	}

	/** Creates a mocked source context; records are never forwarded to it by this class. */
	@SuppressWarnings("unchecked")
	private static SourceFunction.SourceContext<String> getMockedSourceContext() {
		return Mockito.mock(SourceFunction.SourceContext.class);
	}

	/**
	 * Creates a mocked runtime context reporting the given parallelism, subtask index and a fixed
	 * task name.
	 */
	private static RuntimeContext getMockedRuntimeContext(final int fakeTotalCountOfSubtasks, final int fakeIndexOfThisSubtask) {
		RuntimeContext mockedRuntimeContext = Mockito.mock(RuntimeContext.class);

		Mockito.when(mockedRuntimeContext.getNumberOfParallelSubtasks()).thenAnswer(new Answer<Integer>() {
			@Override
			public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
				return fakeTotalCountOfSubtasks;
			}
		});

		Mockito.when(mockedRuntimeContext.getIndexOfThisSubtask()).thenAnswer(new Answer<Integer>() {
			@Override
			public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
				return fakeIndexOfThisSubtask;
			}
		});

		Mockito.when(mockedRuntimeContext.getTaskName()).thenAnswer(new Answer<String>() {
			@Override
			public String answer(InvocationOnMock invocationOnMock) throws Throwable {
				return "Fake Task";
			}
		});

		Mockito.when(mockedRuntimeContext.getTaskNameWithSubtasks()).thenAnswer(new Answer<String>() {
			@Override
			public String answer(InvocationOnMock invocationOnMock) throws Throwable {
				return "Fake Task (" + fakeIndexOfThisSubtask + "/" + fakeTotalCountOfSubtasks + ")";
			}
		});

		return mockedRuntimeContext;
	}
}