/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.api.functions;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
public class StatefulSequenceSourceTest {
@Test
public void testCheckpointRestore() throws Exception {
final int initElement = 0;
final int maxElement = 100;
final Set<Long> expectedOutput = new HashSet<>();
for (long i = initElement; i <= maxElement; i++) {
expectedOutput.add(i);
}
final ConcurrentHashMap<String, List<Long>> outputCollector = new ConcurrentHashMap<>();
final OneShotLatch latchToTrigger1 = new OneShotLatch();
final OneShotLatch latchToWait1 = new OneShotLatch();
final OneShotLatch latchToTrigger2 = new OneShotLatch();
final OneShotLatch latchToWait2 = new OneShotLatch();
final StatefulSequenceSource source1 = new StatefulSequenceSource(initElement, maxElement);
StreamSource<Long, StatefulSequenceSource> src1 = new StreamSource<>(source1);
final AbstractStreamOperatorTestHarness<Long> testHarness1 =
new AbstractStreamOperatorTestHarness<>(src1, 2, 2, 0);
testHarness1.open();
final StatefulSequenceSource source2 = new StatefulSequenceSource(initElement, maxElement);
StreamSource<Long, StatefulSequenceSource> src2 = new StreamSource<>(source2);
final AbstractStreamOperatorTestHarness<Long> testHarness2 =
new AbstractStreamOperatorTestHarness<>(src2, 2, 2, 1);
testHarness2.open();
final Throwable[] error = new Throwable[3];
// run the source asynchronously
Thread runner1 = new Thread() {
@Override
public void run() {
try {
source1.run(new BlockingSourceContext("1", latchToTrigger1, latchToWait1, outputCollector, 21));
}
catch (Throwable t) {
t.printStackTrace();
error[0] = t;
}
}
};
// run the source asynchronously
Thread runner2 = new Thread() {
@Override
public void run() {
try {
source2.run(new BlockingSourceContext("2", latchToTrigger2, latchToWait2, outputCollector, 32));
}
catch (Throwable t) {
t.printStackTrace();
error[1] = t;
}
}
};
runner1.start();
runner2.start();
if (!latchToTrigger1.isTriggered()) {
latchToTrigger1.await();
}
if (!latchToTrigger2.isTriggered()) {
latchToTrigger2.await();
}
OperatorStateHandles snapshot = AbstractStreamOperatorTestHarness.repackageState(
testHarness1.snapshot(0L, 0L),
testHarness2.snapshot(0L, 0L)
);
final StatefulSequenceSource source3 = new StatefulSequenceSource(initElement, maxElement);
StreamSource<Long, StatefulSequenceSource> src3 = new StreamSource<>(source3);
final AbstractStreamOperatorTestHarness<Long> testHarness3 =
new AbstractStreamOperatorTestHarness<>(src3, 2, 1, 0);
testHarness3.setup();
testHarness3.initializeState(snapshot);
testHarness3.open();
final OneShotLatch latchToTrigger3 = new OneShotLatch();
final OneShotLatch latchToWait3 = new OneShotLatch();
latchToWait3.trigger();
// run the source asynchronously
Thread runner3 = new Thread() {
@Override
public void run() {
try {
source3.run(new BlockingSourceContext("3", latchToTrigger3, latchToWait3, outputCollector, 3));
}
catch (Throwable t) {
t.printStackTrace();
error[2] = t;
}
}
};
runner3.start();
runner3.join();
Assert.assertEquals(3, outputCollector.size()); // we have 3 tasks.
// test for at-most-once
Set<Long> dedupRes = new HashSet<>(Math.abs(maxElement - initElement) + 1);
for (Map.Entry<String, List<Long>> elementsPerTask: outputCollector.entrySet()) {
String key = elementsPerTask.getKey();
List<Long> elements = outputCollector.get(key);
// this tests the correctness of the latches in the test
Assert.assertTrue(elements.size() > 0);
for (Long elem : elements) {
if (!dedupRes.add(elem)) {
Assert.fail("Duplicate entry: " + elem);
}
if (!expectedOutput.contains(elem)) {
Assert.fail("Unexpected element: " + elem);
}
}
}
// test for exactly-once
Assert.assertEquals(Math.abs(initElement - maxElement) + 1, dedupRes.size());
latchToWait1.trigger();
latchToWait2.trigger();
// wait for everybody ot finish.
runner1.join();
runner2.join();
}
private static class BlockingSourceContext implements SourceFunction.SourceContext<Long> {
private final String name;
private final Object lock;
private final OneShotLatch latchToTrigger;
private final OneShotLatch latchToWait;
private final ConcurrentHashMap<String, List<Long>> collector;
private final int threshold;
private int counter = 0;
private final List<Long> localOutput;
public BlockingSourceContext(String name, OneShotLatch latchToTrigger, OneShotLatch latchToWait,
ConcurrentHashMap<String, List<Long>> output, int elemToFire) {
this.name = name;
this.lock = new Object();
this.latchToTrigger = latchToTrigger;
this.latchToWait = latchToWait;
this.collector = output;
this.threshold = elemToFire;
this.localOutput = new ArrayList<>();
List<Long> prev = collector.put(name, localOutput);
if (prev != null) {
Assert.fail();
}
}
@Override
public void collectWithTimestamp(Long element, long timestamp) {
collect(element);
}
@Override
public void collect(Long element) {
localOutput.add(element);
if (++counter == threshold) {
latchToTrigger.trigger();
try {
if (!latchToWait.isTriggered()) {
latchToWait.await();
}
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
@Override
public void emitWatermark(Watermark mark) {
throw new UnsupportedOperationException();
}
@Override
public void markAsTemporarilyIdle() {
throw new UnsupportedOperationException();
}
@Override
public Object getCheckpointLock() {
return lock;
}
@Override
public void close() {
}
}
}