/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.io.network.partition.consumer;

import com.google.common.collect.Lists;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemoryType;
import org.apache.flink.runtime.execution.CancelTaskException;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.runtime.io.network.util.TestPartitionProducer;
import org.apache.flink.runtime.io.network.util.TestProducerSource;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import scala.Tuple2;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static org.apache.flink.util.Preconditions.checkArgument;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class LocalInputChannelTest {

	/**
	 * Tests the consumption of multiple subpartitions via local input channels.
	 *
	 * <p> Multiple producer tasks produce pipelined partitions, which are consumed by multiple
	 * tasks via local input channels.
	 */
	@Test
	public void testConcurrentConsumeMultiplePartitions() throws Exception {
		// Config
		final int parallelism = 32;
		final int producerBufferPoolSize = parallelism + 1;
		final int numberOfBuffersPerChannel = 1024;

		checkArgument(parallelism >= 1);
		checkArgument(producerBufferPoolSize >= parallelism);
		checkArgument(numberOfBuffersPerChannel >= 1);

		// Setup
		// One thread per produced partition and one per consumer
		final ExecutorService executor = Executors.newFixedThreadPool(2 * parallelism);

		final NetworkBufferPool networkBuffers = new NetworkBufferPool(
			(parallelism * producerBufferPoolSize) + (parallelism * parallelism),
			TestBufferFactory.BUFFER_SIZE,
			MemoryType.HEAP);

		final ResultPartitionConsumableNotifier partitionConsumableNotifier =
			mock(ResultPartitionConsumableNotifier.class);

		final TaskActions taskActions = mock(TaskActions.class);

		final IOManager ioManager = mock(IOManager.class);

		final JobID jobId = new JobID();

		final ResultPartitionManager partitionManager = new ResultPartitionManager();

		final ResultPartitionID[] partitionIds = new ResultPartitionID[parallelism];
		final TestPartitionProducer[] partitionProducers = new TestPartitionProducer[parallelism];

		// Create all partitions
		for (int i = 0; i < parallelism; i++) {
			partitionIds[i] = new ResultPartitionID();

			final ResultPartition partition = new ResultPartition(
				"Test Name",
				taskActions,
				jobId,
				partitionIds[i],
				ResultPartitionType.PIPELINED,
				parallelism,
				parallelism,
				partitionManager,
				partitionConsumableNotifier,
				ioManager,
				true);

			// Create a buffer pool for this partition
			partition.registerBufferPool(
				networkBuffers.createBufferPool(producerBufferPoolSize, producerBufferPoolSize));

			// Create the producer
			partitionProducers[i] = new TestPartitionProducer(
				partition,
				false,
				new TestPartitionProducerBufferSource(
					parallelism,
					partition.getBufferProvider(),
					numberOfBuffersPerChannel));

			// Register with the partition manager in order to allow the local input channels to
			// request their respective partitions.
			partitionManager.registerResultPartition(partition);
		}

		// Test
		try {
			// Submit producer tasks
			List<Future<?>> results = Lists.newArrayListWithCapacity(parallelism + 1);

			for (int i = 0; i < parallelism; i++) {
				results.add(executor.submit(partitionProducers[i]));
			}

			// Submit consumers
			for (int i = 0; i < parallelism; i++) {
				results.add(executor.submit(
					new TestLocalInputChannelConsumer(
						i,
						parallelism,
						numberOfBuffersPerChannel,
						networkBuffers.createBufferPool(parallelism, parallelism),
						partitionManager,
						new TaskEventDispatcher(),
						partitionIds)));
			}

			// Wait for all to finish
			for (Future<?> result : results) {
				result.get();
			}
		} finally {
			networkBuffers.destroy();
			executor.shutdown();
		}
	}

	@Test
	public void testPartitionRequestExponentialBackoff() throws Exception {
		// Config
		Tuple2<Integer, Integer> backoff = new Tuple2<>(500, 3000);

		// Start with initial backoff, then keep doubling, and cap at max.
		int[] expectedDelays = {backoff._1(), 1000, 2000, backoff._2()};

		// Setup
		SingleInputGate inputGate = mock(SingleInputGate.class);

		BufferProvider bufferProvider = mock(BufferProvider.class);
		when(inputGate.getBufferProvider()).thenReturn(bufferProvider);

		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);

		LocalInputChannel ch = createLocalInputChannel(inputGate, partitionManager, backoff);

		when(partitionManager
			.createSubpartitionView(eq(ch.partitionId), eq(0), any(BufferAvailabilityListener.class)))
			.thenThrow(new PartitionNotFoundException(ch.partitionId));

		Timer timer = mock(Timer.class);
		doAnswer(new Answer<Void>() {
			@Override
			public Void answer(InvocationOnMock invocation) throws Throwable {
				((TimerTask) invocation.getArguments()[0]).run();
				return null;
			}
		}).when(timer).schedule(any(TimerTask.class), anyLong());

		// Initial request
		ch.requestSubpartition(0);
		verify(partitionManager)
			.createSubpartitionView(eq(ch.partitionId), eq(0), any(BufferAvailabilityListener.class));

		// Request subpartition and verify that the actual requests are delayed.
		for (long expected : expectedDelays) {
			ch.retriggerSubpartitionRequest(timer, 0);

			verify(timer).schedule(any(TimerTask.class), eq(expected));
		}

		// Exception after backoff is greater than the maximum backoff.
		try {
			ch.retriggerSubpartitionRequest(timer, 0);
			ch.getNextBuffer();
			fail("Did not throw expected exception.");
		} catch (Exception expected) {
		}
	}
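
	/**
	 * Illustrative sketch (not used by the tests): derives the expected sequence of partition
	 * request backoff delays, assuming the channel starts at the initial backoff and keeps
	 * doubling it until it is capped at the maximum, as described in
	 * {@link #testPartitionRequestExponentialBackoff()}. For the (500, 3000) configuration used
	 * there this yields 500, 1000, 2000, 3000, matching the hard-coded {@code expectedDelays}.
	 */
	@SuppressWarnings("unused")
	private static List<Integer> expectedBackoffDelays(int initialBackoff, int maxBackoff) {
		List<Integer> delays = Lists.newArrayList();

		// Keep doubling the delay until it would exceed the maximum...
		for (int delay = initialBackoff; delay < maxBackoff; delay *= 2) {
			delays.add(delay);
		}

		// ...and cap the sequence at the maximum backoff.
		delays.add(maxBackoff);

		return delays;
	}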

	@Test(expected = CancelTaskException.class)
	public void testProducerFailedException() throws Exception {
		ResultSubpartitionView view = mock(ResultSubpartitionView.class);
		when(view.isReleased()).thenReturn(true);
		when(view.getFailureCause()).thenReturn(new Exception("Expected test exception"));

		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
		when(partitionManager
			.createSubpartitionView(any(ResultPartitionID.class), anyInt(), any(BufferAvailabilityListener.class)))
			.thenReturn(view);

		SingleInputGate inputGate = mock(SingleInputGate.class);
		BufferProvider bufferProvider = mock(BufferProvider.class);
		when(inputGate.getBufferProvider()).thenReturn(bufferProvider);

		LocalInputChannel ch = createLocalInputChannel(
			inputGate, partitionManager, new Tuple2<>(0, 0));

		ch.requestSubpartition(0);

		// Should throw an instance of CancelTaskException.
		ch.getNextBuffer();
	}

	/**
	 * Verifies that concurrent release via the SingleInputGate and re-triggering
	 * of a partition request works smoothly.
	 *
	 * - SingleInputGate acquires its request lock and tries to release all
	 *   registered channels. When releasing a channel, it needs to acquire
	 *   the channel's shared request-release lock.
	 * - If a LocalInputChannel concurrently retriggers a partition request via
	 *   a Timer thread, it acquires the channel's request-release lock and calls
	 *   the retrigger callback on the SingleInputGate, which again tries to
	 *   acquire the gate's request lock.
	 *
	 * For certain timings this obviously leads to a deadlock. This test reliably
	 * reproduced such a timing (reported in FLINK-5228). This test is pretty much
	 * testing the buggy implementation and has not much more general value. If it
	 * becomes obsolete at some point (future greatness ;)), feel free to remove it.
	 *
	 * The fix in the end was to not acquire the channel's lock when releasing it
	 * and/or to not do any input gate callbacks while holding the channel's lock.
	 * I decided to do both.
	 */
	@Test
	public void testConcurrentReleaseAndRetriggerPartitionRequest() throws Exception {
		final SingleInputGate gate = new SingleInputGate(
			"test task name",
			new JobID(),
			new IntermediateDataSetID(),
			ResultPartitionType.PIPELINED,
			0,
			1,
			mock(TaskActions.class),
			new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
		when(partitionManager
			.createSubpartitionView(
				any(ResultPartitionID.class),
				anyInt(),
				any(BufferAvailabilityListener.class)))
			.thenAnswer(new Answer<ResultSubpartitionView>() {
				@Override
				public ResultSubpartitionView answer(InvocationOnMock invocationOnMock) throws Throwable {
					// Sleep here a little to give the releaser Thread
					// time to acquire the input gate lock. We throw
					// the Exception to retrigger the request.
					Thread.sleep(100);
					throw new PartitionNotFoundException(new ResultPartitionID());
				}
			});

		final LocalInputChannel channel = new LocalInputChannel(
			gate,
			0,
			new ResultPartitionID(),
			partitionManager,
			new TaskEventDispatcher(),
			1,
			1,
			new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

		gate.setInputChannel(new IntermediateResultPartitionID(), channel);

		Thread releaser = new Thread() {
			@Override
			public void run() {
				try {
					gate.releaseAllResources();
				} catch (IOException ignored) {
				}
			}
		};

		Thread requester = new Thread() {
			@Override
			public void run() {
				try {
					channel.requestSubpartition(0);
				} catch (IOException | InterruptedException ignored) {
				}
			}
		};

		requester.start();
		releaser.start();

		releaser.join();
		requester.join();
	}
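
	/**
	 * Minimal sketch of the lock inversion described above. It is not executed by any test and
	 * only mirrors the locking order with plain monitors; the {@code gateLock} and
	 * {@code channelLock} parameters are stand-ins for the gate's request lock and the channel's
	 * request-release lock, not the actual Flink objects. The releaser takes the gate lock and
	 * then the channel lock, while the retrigger path takes them in the opposite order, which
	 * can deadlock for unlucky timings.
	 */
	@SuppressWarnings("unused")
	private static void lockInversionSketch(final Object gateLock, final Object channelLock) {
		Thread releaser = new Thread() {
			@Override
			public void run() {
				synchronized (gateLock) {
					// Releasing a registered channel requires its request-release lock.
					synchronized (channelLock) {
						// Release channel resources.
					}
				}
			}
		};

		Thread retrigger = new Thread() {
			@Override
			public void run() {
				synchronized (channelLock) {
					// The retrigger callback into the gate requires the gate's request lock.
					synchronized (gateLock) {
						// Re-request the subpartition.
					}
				}
			}
		};

		releaser.start();
		retrigger.start();
	}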

	/**
	 * Tests that reads from a channel after the partition has been released
	 * are handled properly and don't lead to NPEs.
	 */
	@Test
	public void testGetNextAfterPartitionReleased() throws Exception {
		ResultSubpartitionView reader = mock(ResultSubpartitionView.class);
		SingleInputGate gate = mock(SingleInputGate.class);
		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);

		when(partitionManager.createSubpartitionView(
			any(ResultPartitionID.class),
			anyInt(),
			any(BufferAvailabilityListener.class))).thenReturn(reader);

		LocalInputChannel channel = new LocalInputChannel(
			gate,
			0,
			new ResultPartitionID(),
			partitionManager,
			new TaskEventDispatcher(),
			new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

		channel.requestSubpartition(0);

		// Null buffer but not released
		when(reader.getNextBuffer()).thenReturn(null);
		when(reader.isReleased()).thenReturn(false);

		try {
			channel.getNextBuffer();
			fail("Did not throw expected IllegalStateException");
		} catch (IllegalStateException ignored) {
		}

		// Null buffer and released
		when(reader.getNextBuffer()).thenReturn(null);
		when(reader.isReleased()).thenReturn(true);

		try {
			channel.getNextBuffer();
			fail("Did not throw expected CancelTaskException");
		} catch (CancelTaskException ignored) {
		}
	}

	// ---------------------------------------------------------------------------------------------

	private LocalInputChannel createLocalInputChannel(
			SingleInputGate inputGate,
			ResultPartitionManager partitionManager,
			Tuple2<Integer, Integer> initialAndMaxRequestBackoff)
			throws IOException, InterruptedException {

		return new LocalInputChannel(
			inputGate,
			0,
			new ResultPartitionID(),
			partitionManager,
			mock(TaskEventDispatcher.class),
			initialAndMaxRequestBackoff._1(),
			initialAndMaxRequestBackoff._2(),
			new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
	}

	/**
	 * Returns the configured number of buffers for each channel in a random order.
	 */
	private static class TestPartitionProducerBufferSource implements TestProducerSource {

		private final BufferProvider bufferProvider;

		private final List<Byte> channelIndexes;

		public TestPartitionProducerBufferSource(
				int parallelism,
				BufferProvider bufferProvider,
				int numberOfBuffersToProduce) {

			this.bufferProvider = bufferProvider;
			this.channelIndexes = Lists.newArrayListWithCapacity(
				parallelism * numberOfBuffersToProduce);

			// Array of channel indexes to produce buffers for
			for (byte i = 0; i < parallelism; i++) {
				for (int j = 0; j < numberOfBuffersToProduce; j++) {
					channelIndexes.add(i);
				}
			}

			// Random buffer to channel ordering
			Collections.shuffle(channelIndexes);
		}

		@Override
		public BufferOrEvent getNextBufferOrEvent() throws Exception {
			if (channelIndexes.size() > 0) {
				final int channelIndex = channelIndexes.remove(0);

				return new BufferOrEvent(bufferProvider.requestBufferBlocking(), channelIndex);
			}

			return null;
		}
	}

	/**
	 * Consumes the configured result partitions and verifies that each channel receives the
	 * expected number of buffers.
	 */
	private static class TestLocalInputChannelConsumer implements Callable<Void> {

		private final SingleInputGate inputGate;

		private final int numberOfInputChannels;

		private final int numberOfExpectedBuffersPerChannel;

		public TestLocalInputChannelConsumer(
				int subpartitionIndex,
				int numberOfInputChannels,
				int numberOfExpectedBuffersPerChannel,
				BufferPool bufferPool,
				ResultPartitionManager partitionManager,
				TaskEventDispatcher taskEventDispatcher,
				ResultPartitionID[] consumedPartitionIds) {

			checkArgument(numberOfInputChannels >= 1);
			checkArgument(numberOfExpectedBuffersPerChannel >= 1);

			this.inputGate = new SingleInputGate(
				"Test Name",
				new JobID(),
				new IntermediateDataSetID(),
				ResultPartitionType.PIPELINED,
				subpartitionIndex,
				numberOfInputChannels,
				mock(TaskActions.class),
				new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

			// Set buffer pool
			inputGate.setBufferPool(bufferPool);

			// Setup input channels
			for (int i = 0; i < numberOfInputChannels; i++) {
				inputGate.setInputChannel(
					new IntermediateResultPartitionID(),
					new LocalInputChannel(
						inputGate,
						i,
						consumedPartitionIds[i],
						partitionManager,
						taskEventDispatcher,
						new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup()));
			}

			this.numberOfInputChannels = numberOfInputChannels;
			this.numberOfExpectedBuffersPerChannel = numberOfExpectedBuffersPerChannel;
		}

		@Override
		public Void call() throws Exception {
			// One counter per input channel. Expect the same number of buffers from each channel.
			final int[] numberOfBuffersPerChannel = new int[numberOfInputChannels];

			try {
				BufferOrEvent boe;
				while ((boe = inputGate.getNextBufferOrEvent()) != null) {
					if (boe.isBuffer()) {
						boe.getBuffer().recycle();

						// Check that we don't receive too many buffers
						if (++numberOfBuffersPerChannel[boe.getChannelIndex()]
							> numberOfExpectedBuffersPerChannel) {

							throw new IllegalStateException("Received more buffers than expected " +
								"on channel " + boe.getChannelIndex() + ".");
						}
					}
				}

				// Verify that we received the expected number of buffers on each channel
				for (int i = 0; i < numberOfBuffersPerChannel.length; i++) {
					final int actualNumberOfReceivedBuffers = numberOfBuffersPerChannel[i];

					if (actualNumberOfReceivedBuffers != numberOfExpectedBuffersPerChannel) {
						throw new IllegalStateException("Received unexpected number of buffers " +
							"on channel " + i + " (" + actualNumberOfReceivedBuffers + " instead " +
							"of " + numberOfExpectedBuffersPerChannel + ").");
					}
				}
			} finally {
				inputGate.releaseAllResources();
			}

			return null;
		}
	}
}