/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.io.network.partition; import org.apache.flink.api.common.JobID; import org.apache.flink.core.testutils.CheckedThread; import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup; import org.apache.flink.runtime.taskmanager.TaskActions; import org.junit.Test; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Random; import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager; import static 
org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Mockito.mock;

/**
 * Concurrency tests for {@link SingleInputGate}: a producer thread pushes buffers into randomly
 * chosen channels while a consumer thread concurrently pulls the same number of buffers out of
 * the gate.
 *
 * <p>The gate is backed by {@link LocalInputChannel}s only, by {@link RemoteInputChannel}s only,
 * or by a shuffled mix of both.
 */
public class InputGateConcurrentTest {

	/** Upper bound of buffers the producer adds to one source before picking another one. */
	private static final int MAX_CHUNK = 4;

	/** The producer yields its time slice roughly every this many buffers. */
	private static final int YIELD_AFTER = 10;

	@Test
	public void testConsumptionWithLocalChannels() throws Exception {
		final int numChannels = 11;
		final int buffersPerChannel = 1000;

		final ResultPartition resultPartition = mock(ResultPartition.class);

		final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels];
		final Source[] sources = new Source[numChannels];

		// NOTE(review): 'partitions' is populated only after it was handed to the manager below;
		// presumably createResultPartitionManager reads the array lazily - confirm.
		final ResultPartitionManager resultPartitionManager = createResultPartitionManager(partitions);

		final SingleInputGate gate = createInputGate(numChannels);

		for (int i = 0; i < numChannels; i++) {
			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
				resultPartitionManager, mock(TaskEventDispatcher.class),
				new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
			gate.setInputChannel(new IntermediateResultPartitionID(), channel);

			partitions[i] = new PipelinedSubpartition(0, resultPartition);
			sources[i] = new PipelinedSubpartitionSource(partitions[i]);
		}

		runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
	}

	@Test
	public void testConsumptionWithRemoteChannels() throws Exception {
		final int numChannels = 11;
		final int buffersPerChannel = 1000;

		final ConnectionManager connManager = createDummyConnectionManager();
		final Source[] sources = new Source[numChannels];

		final SingleInputGate gate = createInputGate(numChannels);

		for (int i = 0; i < numChannels; i++) {
			RemoteInputChannel channel = new RemoteInputChannel(
				gate, i, new ResultPartitionID(), mock(ConnectionID.class),
				connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
			gate.setInputChannel(new IntermediateResultPartitionID(), channel);

			sources[i] = new RemoteChannelSource(channel);
		}

		runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
	}

	@Test
	public void testConsumptionWithMixedChannels() throws Exception {
		final int numChannels = 61;
		final int numLocalChannels = 20;
		final int buffersPerChannel = 1000;

		// decide per channel whether it is local (true) or remote (false), in random order
		List<Boolean> localOrRemote = new ArrayList<>(numChannels);
		for (int i = 0; i < numChannels; i++) {
			localOrRemote.add(i < numLocalChannels);
		}
		Collections.shuffle(localOrRemote);

		final ConnectionManager connManager = createDummyConnectionManager();
		final ResultPartition resultPartition = mock(ResultPartition.class);

		final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels];
		// NOTE(review): 'localPartitions' is populated only after it was handed to the manager;
		// presumably createResultPartitionManager reads the array lazily - confirm.
		final ResultPartitionManager resultPartitionManager = createResultPartitionManager(localPartitions);

		final Source[] sources = new Source[numChannels];

		final SingleInputGate gate = createInputGate(numChannels);

		for (int i = 0, local = 0; i < numChannels; i++) {
			if (localOrRemote.get(i)) {
				// local channel
				PipelinedSubpartition psp = new PipelinedSubpartition(0, resultPartition);
				localPartitions[local++] = psp;
				sources[i] = new PipelinedSubpartitionSource(psp);

				LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
					resultPartitionManager, mock(TaskEventDispatcher.class),
					new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
				gate.setInputChannel(new IntermediateResultPartitionID(), channel);
			}
			else {
				// remote channel
				RemoteInputChannel channel = new RemoteInputChannel(
					gate, i, new ResultPartitionID(), mock(ConnectionID.class),
					connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
				gate.setInputChannel(new IntermediateResultPartitionID(), channel);

				sources[i] = new RemoteChannelSource(channel);
			}
		}

		runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
	}

	// ------------------------------------------------------------------------
	//  test setup helpers
	// ------------------------------------------------------------------------

	/**
	 * Creates the pipelined {@link SingleInputGate} used by every test in this class.
	 *
	 * @param numChannels number of input channels the gate is created for
	 * @return a gate without any channels set yet
	 */
	private static SingleInputGate createInputGate(int numChannels) {
		return new SingleInputGate(
			"Test Task Name",
			new JobID(),
			new IntermediateDataSetID(),
			ResultPartitionType.PIPELINED,
			0,
			numChannels,
			mock(TaskActions.class),
			new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
	}

	/**
	 * Runs one {@link ProducerThread} that spreads {@code numBuffers} buffers over {@code sources}.
	 * It runs alongside one {@link ConsumerThread} that pulls exactly {@code numBuffers} buffers
	 * from {@code gate}. The method then waits for both threads.
	 *
	 * @param sources the sources feeding the gate's channels
	 * @param gate the gate under test
	 * @param numBuffers total number of buffers to produce and to consume
	 * @throws Exception any exception or failed assertion from either thread
	 */
	private static void runProducerAndConsumer(
			Source[] sources, SingleInputGate gate, int numBuffers) throws Exception {

		ProducerThread producer = new ProducerThread(sources, numBuffers, MAX_CHUNK, YIELD_AFTER);
		ConsumerThread consumer = new ConsumerThread(gate, numBuffers);

		producer.start();
		consumer.start();

		// the 'sync()' call checks for exceptions and failed assertions
		producer.sync();
		consumer.sync();
	}

	// ------------------------------------------------------------------------
	//  testing sources
	// ------------------------------------------------------------------------

	/** Something the producer can push buffers into; each instance feeds one channel. */
	private abstract static class Source {

		abstract void addBuffer(Buffer buffer) throws Exception;
	}

	/** Feeds a local channel by adding buffers to its {@link PipelinedSubpartition}. */
	private static class PipelinedSubpartitionSource extends Source {

		final PipelinedSubpartition partition;

		PipelinedSubpartitionSource(PipelinedSubpartition partition) {
			this.partition = partition;
		}

		@Override
		void addBuffer(Buffer buffer) throws Exception {
			partition.add(buffer);
		}
	}

	/**
	 * Feeds a {@link RemoteInputChannel} directly via {@code onBuffer}. It passes consecutive
	 * sequence numbers starting at 0.
	 */
	private static class RemoteChannelSource extends Source {

		final RemoteInputChannel channel;

		private int seq = 0;

		RemoteChannelSource(RemoteInputChannel channel) {
			this.channel = channel;
		}

		@Override
		void addBuffer(Buffer buffer) throws Exception {
			channel.onBuffer(buffer, seq++);
		}
	}

	// ------------------------------------------------------------------------
	//  testing threads
	// ------------------------------------------------------------------------

	/**
	 * Adds {@code numTotal} buffers in total to randomly chosen sources. Each random pick gets a
	 * chunk of 1 to {@code maxChunk} buffers. The thread calls {@link Thread#yield()} each time
	 * another {@code yieldAfter} buffers have been added.
	 */
	private static class ProducerThread extends CheckedThread {

		private final Random rnd = new Random();
		private final Source[] sources;
		private final int numTotal;
		private final int maxChunk;
		private final int yieldAfter;

		ProducerThread(Source[] sources, int numTotal, int maxChunk, int yieldAfter) {
			this.sources = sources;
			this.numTotal = numTotal;
			this.maxChunk = maxChunk;
			this.yieldAfter = yieldAfter;
		}

		@Override
		public void go() throws Exception {
			// the same buffer instance is added repeatedly; presumably a mock of size 100 - confirm
			final Buffer buffer = InputChannelTestUtils.createMockBuffer(100);
			int nextYield = numTotal - yieldAfter;

			for (int i = numTotal; i > 0;) {
				final int nextChannel = rnd.nextInt(sources.length);
				// never produce more than the remaining number of buffers
				final int chunk = Math.min(i, rnd.nextInt(maxChunk) + 1);

				final Source next = sources[nextChannel];

				for (int k = chunk; k > 0; --k) {
					next.addBuffer(buffer);
				}

				i -= chunk;

				if (i <= nextYield) {
					nextYield -= yieldAfter;
					//noinspection CallToThreadYield
					Thread.yield();
				}
			}
		}
	}

	/** Pulls exactly {@code numBuffers} non-null buffers or events from the gate. */
	private static class ConsumerThread extends CheckedThread {

		private final SingleInputGate gate;
		private final int numBuffers;

		ConsumerThread(SingleInputGate gate, int numBuffers) {
			this.gate = gate;
			this.numBuffers = numBuffers;
		}

		@Override
		public void go() throws Exception {
			for (int i = numBuffers; i > 0; --i) {
				assertNotNull(gate.getNextBufferOrEvent());
			}
		}
	}
}