/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.streaming.api.functions.source; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.io.RichInputFormat; import org.apache.flink.api.common.io.statistics.BaseStatistics; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.io.InputSplit; import org.apache.flink.core.io.InputSplitAssigner; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.api.watermark.Watermark; import org.junit.Assert; import org.junit.Test; import java.io.IOException; import java.util.Collections; public class InputFormatSourceFunctionTest { @Test public void testNormalOp() throws Exception { testFormatLifecycle(false); } @Test public void testCancelation() throws Exception { testFormatLifecycle(true); } private void testFormatLifecycle(final boolean midCancel) throws Exception { final int noOfSplits = 5; final int cancelAt = 2; final LifeCycleTestInputFormat format = new LifeCycleTestInputFormat(); final InputFormatSourceFunction<Integer> reader = new InputFormatSourceFunction<>(format, TypeInformation.of(Integer.class)); reader.setRuntimeContext(new MockRuntimeContext(format, noOfSplits)); Assert.assertTrue(!format.isConfigured); Assert.assertTrue(!format.isInputFormatOpen); Assert.assertTrue(!format.isSplitOpen); reader.open(new Configuration()); Assert.assertTrue(format.isConfigured); TestSourceContext ctx = new TestSourceContext(reader, format, midCancel, cancelAt); reader.run(ctx); int splitsSeen = ctx.getSplitsSeen(); Assert.assertTrue(midCancel ? splitsSeen == cancelAt : splitsSeen == noOfSplits); // we have exhausted the splits so the // format and splits should be closed by now Assert.assertTrue(!format.isSplitOpen); Assert.assertTrue(!format.isInputFormatOpen); } private static class LifeCycleTestInputFormat extends RichInputFormat<Integer,InputSplit> { private static final long serialVersionUID = 7408902249499583273L; private boolean isConfigured = false; private boolean isInputFormatOpen = false; private boolean isSplitOpen = false; // end of split private boolean eos = false; private int splitCounter = 0; private int reachedEndCalls = 0; private int nextRecordCalls = 0; @Override public void openInputFormat() { Assert.assertTrue(isConfigured); Assert.assertTrue(!isInputFormatOpen); Assert.assertTrue(!isSplitOpen); this.isInputFormatOpen = true; } @Override public void closeInputFormat() { Assert.assertTrue(!isSplitOpen); this.isInputFormatOpen = false; } @Override public void configure(Configuration parameters) { Assert.assertTrue(!isConfigured); this.isConfigured = true; } @Override public BaseStatistics getStatistics(BaseStatistics cachedStatistics) throws IOException { return null; } @Override public InputSplit[] createInputSplits(int minNumSplits) throws IOException { Assert.assertTrue(isConfigured); InputSplit[] splits = new InputSplit[minNumSplits]; for (int i = 0; i < minNumSplits; i++) { final int idx = i; splits[idx] = new InputSplit() { private static final long serialVersionUID = -1480792932361908285L; @Override public int getSplitNumber() { return idx; } }; } return splits; } @Override public InputSplitAssigner getInputSplitAssigner(InputSplit[] inputSplits) { return null; } @Override public void open(InputSplit split) throws IOException { // whenever a new split opens, // the previous should have been closed Assert.assertTrue(isInputFormatOpen); Assert.assertTrue(isConfigured); Assert.assertTrue(!isSplitOpen); isSplitOpen = true; eos = false; } @Override public boolean reachedEnd() throws IOException { Assert.assertTrue(isInputFormatOpen); Assert.assertTrue(isConfigured); Assert.assertTrue(isSplitOpen); if (!eos) { reachedEndCalls++; } return eos; } @Override public Integer nextRecord(Integer reuse) throws IOException { Assert.assertTrue(isInputFormatOpen); Assert.assertTrue(isConfigured); Assert.assertTrue(isSplitOpen); Assert.assertTrue(reachedEndCalls == ++nextRecordCalls); eos = true; return splitCounter++; } @Override public void close() throws IOException { this.isSplitOpen = false; } } private static class TestSourceContext implements SourceFunction.SourceContext<Integer> { private final InputFormatSourceFunction<Integer> reader; private final LifeCycleTestInputFormat format; private final boolean shouldCancel; private final int cancelAt; int splitIdx = 0; private TestSourceContext(InputFormatSourceFunction<Integer> reader, LifeCycleTestInputFormat format, boolean shouldCancel, int cancelAt) { this.reader = reader; this.format = format; this.shouldCancel = shouldCancel; this.cancelAt = cancelAt; } @Override public void collect(Integer element) { Assert.assertTrue(format.isSplitOpen); Assert.assertTrue(splitIdx == element); if (shouldCancel && splitIdx == cancelAt) { reader.cancel(); } else { splitIdx++; } } @Override public void collectWithTimestamp(Integer element, long timestamp) { throw new UnsupportedOperationException(); } @Override public void emitWatermark(Watermark mark) { throw new UnsupportedOperationException(); } @Override public void markAsTemporarilyIdle() { throw new UnsupportedOperationException(); } @Override public Object getCheckpointLock() { return null; } @Override public void close() { throw new UnsupportedOperationException(); } public int getSplitsSeen() { return this.splitIdx; } } @SuppressWarnings("deprecation") private static class MockRuntimeContext extends StreamingRuntimeContext { private final int noOfSplits; private int nextSplit = 0; private final LifeCycleTestInputFormat format; private InputSplit[] inputSplits; private MockRuntimeContext(LifeCycleTestInputFormat format, int noOfSplits) { super(new MockStreamOperator(), new MockEnvironment("no", 4 * MemoryManager.DEFAULT_PAGE_SIZE, null, 16), Collections.<String, Accumulator<?, ?>>emptyMap()); this.noOfSplits = noOfSplits; this.format = format; } @Override public MetricGroup getMetricGroup() { return new UnregisteredMetricsGroup(); } @Override public InputSplitProvider getInputSplitProvider() { try { this.inputSplits = format.createInputSplits(noOfSplits); Assert.assertTrue(inputSplits.length == noOfSplits); } catch (IOException e) { e.printStackTrace(); } return new InputSplitProvider() { @Override public InputSplit getNextInputSplit(ClassLoader userCodeClassLoader) { if (nextSplit < inputSplits.length) { return inputSplits[nextSplit++]; } return null; } }; } // ------------------------------------------------------------------------ private static class MockStreamOperator extends AbstractStreamOperator<Integer> { private static final long serialVersionUID = -1153976702711944427L; @Override public ExecutionConfig getExecutionConfig() { return new ExecutionConfig(); } } } }