/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.checkpointing;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;

/**
 * A simple test that runs a streaming topology with checkpointing enabled.
 *
 * <p>The test triggers a failure after a while and verifies that, after completion,
 * the state reflects the "exactly once" semantics.
 *
 * <p>It is designed to check partitioned (keyed) state.
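 *
 * <p>The topology unions two parallel integer sources over a fixed key space, keys
 * the stream by the emitted value, accumulates a per-key sum in a {@link ValueState}
 * (failing the job exactly once along the way), and finally counts the records per
 * key in a sink that mirrors the count in both a serializable and a deliberately
 * non-serializable value state.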
 */
@SuppressWarnings("serial")
public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTestBase {

    final long NUM_STRINGS = 10_000_000L;
    static final int NUM_KEYS = 40;

    @Override
    public void testProgram(StreamExecutionEnvironment env) {
        assertTrue("Broken test setup", (NUM_STRINGS / 2) % NUM_KEYS == 0);

        DataStream<Integer> stream1 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
        DataStream<Integer> stream2 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));

        stream1.union(stream2)
                .keyBy(new IdentityKeySelector<Integer>())
                .map(new OnceFailingPartitionedSum(NUM_STRINGS))
                .keyBy(0)
                .addSink(new CounterSink());
    }

    @Override
    public void postSubmit() {
        // verify that we counted exactly right
        for (Entry<Integer, Long> sum : OnceFailingPartitionedSum.allSums.entrySet()) {
            assertEquals(Long.valueOf(sum.getKey() * NUM_STRINGS / NUM_KEYS), sum.getValue());
        }
        for (Long count : CounterSink.allCounts.values()) {
            assertEquals(Long.valueOf(NUM_STRINGS / NUM_KEYS), count);
        }

        assertEquals(NUM_KEYS, CounterSink.allCounts.size());
        assertEquals(NUM_KEYS, OnceFailingPartitionedSum.allSums.size());
    }

    // --------------------------------------------------------------------------------------------
    //  Custom Functions
    // --------------------------------------------------------------------------------------------

    private static class IntGeneratingSourceFunction extends RichParallelSourceFunction<Integer>
            implements Checkpointed<Integer> {

        private final long numElements;

        private int index;
        private int step;

        private volatile boolean isRunning = true;

        static final long[] counts = new long[PARALLELISM];

        IntGeneratingSourceFunction(long numElements) {
            this.numElements = numElements;
        }

        @Override
        public void open(Configuration parameters) throws IOException {
            step = getRuntimeContext().getNumberOfParallelSubtasks();
            // seed the index from the subtask index only on a fresh start;
            // after a restore, restoreState() has already set it
            if (index == 0) {
                index = getRuntimeContext().getIndexOfThisSubtask();
            }
        }

        @Override
        public void run(SourceContext<Integer> ctx) throws Exception {
            final Object lockingObject = ctx.getCheckpointLock();

            while (isRunning && index < numElements) {
                // advance and emit under the checkpoint lock, so the emitted
                // element and the snapshotted index stay consistent
                synchronized (lockingObject) {
                    index += step;
                    ctx.collect(index % NUM_KEYS);
                }
            }
        }

        @Override
        public void close() throws IOException {
            counts[getRuntimeContext().getIndexOfThisSubtask()] = index;
        }

        @Override
        public void cancel() {
            isRunning = false;
        }

        @Override
        public Integer snapshotState(long checkpointId, long checkpointTimestamp) {
            return index;
        }

        @Override
        public void restoreState(Integer state) {
            index = state;
        }
    }

    private static class OnceFailingPartitionedSum
            extends RichMapFunction<Integer, Tuple2<Integer, Long>> {

        private static Map<Integer, Long> allSums = new ConcurrentHashMap<Integer, Long>();

        private static volatile boolean hasFailed = false;

        private final long numElements;

        private long failurePos;
        private long count;

        private ValueState<Long> sum;

        OnceFailingPartitionedSum(long numElements) {
            this.numElements = numElements;
        }

        @Override
        public void open(Configuration parameters) throws IOException {
            // pick a pseudo-random failure position at roughly 60% to 80% of this
            // subtask's share of the elements (the signed % may land below that
            // window, which only makes the failure happen earlier)
            long failurePosMin = (long) (0.6 * numElements
                    / getRuntimeContext().getNumberOfParallelSubtasks());
            long failurePosMax = (long) (0.8 * numElements
                    / getRuntimeContext().getNumberOfParallelSubtasks());

            failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;
            count = 0;

            sum = getRuntimeContext().getState(
                    new ValueStateDescriptor<>("my_state", Long.class, 0L));
        }

        @Override
        public Tuple2<Integer, Long> map(Integer value) throws Exception {
            count++;
            if (!hasFailed && count >= failurePos) {
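                // the first subtask to reach its failure position brings the job
                // down; the static volatile flag ensures this happens only once,
                // so the rest of the run exercises recovery from the checkpoint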
                hasFailed = true;
                throw new Exception("Test Failure");
            }

            long currentSum = sum.value() + value;
            sum.update(currentSum);
            allSums.put(value, currentSum);

            return new Tuple2<Integer, Long>(value, currentSum);
        }
    }

    private static class CounterSink extends RichSinkFunction<Tuple2<Integer, Long>> {

        private static Map<Integer, Long> allCounts = new ConcurrentHashMap<Integer, Long>();

        private ValueState<NonSerializableLong> aCounts;
        private ValueState<Long> bCounts;

        @Override
        public void open(Configuration parameters) throws IOException {
            aCounts = getRuntimeContext().getState(
                    new ValueStateDescriptor<>("a", NonSerializableLong.class, NonSerializableLong.of(0L)));
            bCounts = getRuntimeContext().getState(
                    new ValueStateDescriptor<>("b", Long.class, 0L));
        }

        @Override
        public void invoke(Tuple2<Integer, Long> value) throws Exception {
            long ac = aCounts.value().value;
            long bc = bCounts.value();
            // the two states are updated in lockstep and must never diverge
            assertEquals(ac, bc);

            long currentCount = ac + 1;
            aCounts.update(NonSerializableLong.of(currentCount));
            bCounts.update(currentCount);

            allCounts.put(value.f0, currentCount);
        }
    }

    /**
     * A long wrapper that is intentionally not {@link java.io.Serializable}; storing
     * it in the sink's "a" state checks that keyed state also works for types that
     * are not Java-serializable.
     */
    public static class NonSerializableLong {

        public Long value;

        private NonSerializableLong(long value) {
            this.value = value;
        }

        public static NonSerializableLong of(long value) {
            return new NonSerializableLong(value);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }

            NonSerializableLong that = (NonSerializableLong) o;
            return value.equals(that.value);
        }

        @Override
        public int hashCode() {
            return value.hashCode();
        }
    }

    public static class IdentityKeySelector<T> implements KeySelector<T, T> {

        @Override
        public T getKey(T value) throws Exception {
            return value;
        }
    }
}