/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.network.partition;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.io.network.ConnectionID;
import org.apache.flink.runtime.io.network.ConnectionManager;
import org.apache.flink.runtime.io.network.TaskEventDispatcher;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.apache.flink.runtime.taskmanager.TaskActions;
import org.junit.Test;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager;
import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createMockBuffer;
import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
public class InputGateFairnessTest {
/**
 * Checks fair consumption from local channels whose subpartitions were completely filled
 * and finished before the gate starts reading.
 *
 * <p>After every read, the buffer counts reported by the subpartitions must differ by at
 * most one between any two channels.
 */
@Test
public void testFairConsumptionLocalChannelsPreFilled() throws Exception {
    final int numChannels = 37;
    final int buffersPerChannel = 27;

    final ResultPartition resultPartition = mock(ResultPartition.class);
    final Buffer mockBuffer = createMockBuffer(42);

    // ----- create some source channels and fill them with buffers -----

    final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];

    for (int i = 0; i < numChannels; i++) {
        PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition);

        for (int p = 0; p < buffersPerChannel; p++) {
            partition.add(mockBuffer);
        }

        partition.finish();
        sources[i] = partition;
    }

    // ----- create reading side -----

    ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources);

    SingleInputGate gate = new FairnessVerifyingInputGate(
            "Test Task Name",
            new JobID(),
            new IntermediateDataSetID(),
            0, numChannels,
            mock(TaskActions.class),
            new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

    for (int i = 0; i < numChannels; i++) {
        LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
                resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
        gate.setInputChannel(new IntermediateResultPartitionID(), channel);
    }

    // read all the buffers plus one extra element per channel
    // (presumably the end-of-partition event enqueued by finish())
    for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
        assertNotNull(gate.getNextBufferOrEvent());

        int min = Integer.MAX_VALUE;
        int max = 0;

        for (PipelinedSubpartition source : sources) {
            int size = source.getCurrentNumberOfBuffers();
            min = Math.min(min, size);
            max = Math.max(max, size);
        }

        // report the observed spread so that a failure can be diagnosed from the message
        assertTrue(
                "Unfair consumption with " + i + " reads remaining: min=" + min + ", max=" + max,
                max == min || max == min + 1);
    }

    // everything was consumed, so the gate has nothing more to hand out
    assertNull(gate.getNextBufferOrEvent());
}
/**
 * Checks fair consumption from local channels that start out empty and are refilled
 * in random order while the gate is being read.
 *
 * <p>After every read, the buffer counts reported by the subpartitions must differ by at
 * most one between any two channels. The subpartitions are never finished, so this test
 * does not read any end-of-partition event and leaves buffers in the queues when it ends.
 */
@Test
public void testFairConsumptionLocalChannels() throws Exception {
    final int numChannels = 37;
    final int buffersPerChannel = 27;

    final ResultPartition resultPartition = mock(ResultPartition.class);
    final Buffer mockBuffer = createMockBuffer(42);

    // ----- create the (initially empty) source channels -----

    final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];

    for (int i = 0; i < numChannels; i++) {
        sources[i] = new PipelinedSubpartition(0, resultPartition);
    }

    // ----- create reading side -----

    ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources);

    SingleInputGate gate = new FairnessVerifyingInputGate(
            "Test Task Name",
            new JobID(),
            new IntermediateDataSetID(),
            0, numChannels,
            mock(TaskActions.class),
            new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

    for (int i = 0; i < numChannels; i++) {
        LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
                resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
        gate.setInputChannel(new IntermediateResultPartitionID(), channel);
    }

    // seed one initial buffer
    sources[12].add(mockBuffer);

    // read a fixed number of buffers, refilling the channels along the way
    for (int i = 0; i < numChannels * buffersPerChannel; i++) {
        assertNotNull(gate.getNextBufferOrEvent());

        int min = Integer.MAX_VALUE;
        int max = 0;

        for (PipelinedSubpartition source : sources) {
            int size = source.getCurrentNumberOfBuffers();
            min = Math.min(min, size);
            max = Math.max(max, size);
        }

        // report the observed spread so that a failure can be diagnosed from the message
        assertTrue(
                "Unfair consumption at read " + i + ": min=" + min + ", max=" + max,
                max == min || max == min + 1);

        if (i % (2 * numChannels) == 0) {
            // add three buffers to each channel, in random order
            fillRandom(sources, 3, mockBuffer);
        }
    }

    // there is still more in the queues
}
/**
 * Checks fair consumption from remote channels that received all of their buffers and a
 * serialized {@link EndOfPartitionEvent} before the gate starts reading.
 *
 * <p>After every read, the queued-buffer counts reported by the channels must differ by at
 * most one between any two channels.
 */
@Test
public void testFairConsumptionRemoteChannelsPreFilled() throws Exception {
    final int numChannels = 37;
    final int buffersPerChannel = 27;

    final Buffer mockBuffer = createMockBuffer(42);

    // ----- create the gate and the remote channels, and fill the channels with buffers -----

    SingleInputGate gate = new FairnessVerifyingInputGate(
            "Test Task Name",
            new JobID(),
            new IntermediateDataSetID(),
            0, numChannels,
            mock(TaskActions.class),
            new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

    final ConnectionManager connManager = createDummyConnectionManager();

    final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];

    for (int i = 0; i < numChannels; i++) {
        RemoteInputChannel channel = new RemoteInputChannel(
                gate, i, new ResultPartitionID(), mock(ConnectionID.class),
                connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

        channels[i] = channel;

        // data buffers use sequence numbers 0 .. buffersPerChannel - 1
        for (int p = 0; p < buffersPerChannel; p++) {
            channel.onBuffer(mockBuffer, p);
        }
        // the end-of-partition event takes the next sequence number
        channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel);

        gate.setInputChannel(new IntermediateResultPartitionID(), channel);
    }

    // read all the buffers and the EOF event
    for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
        assertNotNull(gate.getNextBufferOrEvent());

        int min = Integer.MAX_VALUE;
        int max = 0;

        for (RemoteInputChannel channel : channels) {
            int size = channel.getNumberOfQueuedBuffers();
            min = Math.min(min, size);
            max = Math.max(max, size);
        }

        // report the observed spread so that a failure can be diagnosed from the message
        assertTrue(
                "Unfair consumption with " + i + " reads remaining: min=" + min + ", max=" + max,
                max == min || max == min + 1);
    }

    // everything was consumed, so the gate has nothing more to hand out
    assertNull(gate.getNextBufferOrEvent());
}
/**
 * Checks fair consumption from remote channels that start out empty and receive buffers
 * in random order while the gate is being read.
 *
 * <p>After every read, the queued-buffer counts reported by the channels must differ by at
 * most one between any two channels. No end-of-partition event is ever enqueued, so this
 * test only reads data buffers.
 */
@Test
public void testFairConsumptionRemoteChannels() throws Exception {
    final int numChannels = 37;
    final int buffersPerChannel = 27;

    final Buffer mockBuffer = createMockBuffer(42);

    // ----- create the gate and the (initially empty) remote channels -----

    SingleInputGate gate = new FairnessVerifyingInputGate(
            "Test Task Name",
            new JobID(),
            new IntermediateDataSetID(),
            0, numChannels,
            mock(TaskActions.class),
            new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

    final ConnectionManager connManager = createDummyConnectionManager();

    final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];
    // next sequence number to hand to onBuffer(), tracked per channel
    final int[] channelSequenceNums = new int[numChannels];

    for (int i = 0; i < numChannels; i++) {
        RemoteInputChannel channel = new RemoteInputChannel(
                gate, i, new ResultPartitionID(), mock(ConnectionID.class),
                connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());

        channels[i] = channel;
        gate.setInputChannel(new IntermediateResultPartitionID(), channel);
    }

    // seed one initial buffer
    channels[11].onBuffer(mockBuffer, 0);
    channelSequenceNums[11]++;

    // read a fixed number of buffers, refilling the channels along the way
    for (int i = 0; i < numChannels * buffersPerChannel; i++) {
        assertNotNull(gate.getNextBufferOrEvent());

        int min = Integer.MAX_VALUE;
        int max = 0;

        for (RemoteInputChannel channel : channels) {
            int size = channel.getNumberOfQueuedBuffers();
            min = Math.min(min, size);
            max = Math.max(max, size);
        }

        // report the observed spread so that a failure can be diagnosed from the message
        assertTrue(
                "Unfair consumption at read " + i + ": min=" + min + ", max=" + max,
                max == min || max == min + 1);

        if (i % (2 * numChannels) == 0) {
            // add three buffers to each channel, in random order
            fillRandom(channels, channelSequenceNums, 3, mockBuffer);
        }
    }
}
// ------------------------------------------------------------------------
// Utilities
// ------------------------------------------------------------------------
/**
 * Adds {@code numPerPartition} references to {@code buffer} to every subpartition,
 * interleaving the additions across the subpartitions in a random order.
 *
 * @param partitions the subpartitions to fill
 * @param numPerPartition how many times the buffer is added to each subpartition
 * @param buffer the buffer instance to add
 */
private void fillRandom(PipelinedSubpartition[] partitions, int numPerPartition, Buffer buffer) throws Exception {
    for (int index : shuffledIndices(partitions.length, numPerPartition)) {
        partitions[index].add(buffer);
    }
}

/**
 * Delivers {@code buffer} to every channel {@code numPerPartition} times, interleaving the
 * deliveries across the channels in a random order.
 *
 * @param partitions the remote channels to fill
 * @param sequenceNumbers next sequence number per channel; the entry of a channel is used
 *                        for the delivery and then incremented
 * @param numPerPartition how many times the buffer is delivered to each channel
 * @param buffer the buffer instance to deliver
 */
private void fillRandom(
        RemoteInputChannel[] partitions,
        int[] sequenceNumbers,
        int numPerPartition,
        Buffer buffer) throws Exception {

    for (int index : shuffledIndices(partitions.length, numPerPartition)) {
        partitions[index].onBuffer(buffer, sequenceNumbers[index]++);
    }
}

/**
 * Builds a randomly ordered list in which every index in {@code [0, numTargets)} occurs
 * exactly {@code repeats} times.
 */
private static ArrayList<Integer> shuffledIndices(int numTargets, int repeats) {
    ArrayList<Integer> indices = new ArrayList<>(numTargets * repeats);

    for (int target = 0; target < numTargets; target++) {
        for (int k = 0; k < repeats; k++) {
            indices.add(target);
        }
    }

    Collections.shuffle(indices);
    return indices;
}
// ------------------------------------------------------------------------
/**
 * A {@link SingleInputGate} that runs sanity checks on its queue of channels with data
 * before every read: the queue must not hold more entries than the gate has input
 * channels, and it must not contain the same channel twice.
 *
 * <p>The queue is obtained reflectively from the {@code inputChannelsWithData} field
 * declared in {@link SingleInputGate}.
 */
private static class FairnessVerifyingInputGate extends SingleInputGate {

    /** The parent gate's queue of channels with data, obtained via reflection. */
    private final ArrayDeque<InputChannel> channelsWithData;

    /** Scratch set reused across calls for duplicate detection; empty between checks. */
    private final HashSet<InputChannel> uniquenessChecker;

    public FairnessVerifyingInputGate(
            String owningTaskName,
            JobID jobId,
            IntermediateDataSetID consumedResultId,
            int consumedSubpartitionIndex,
            int numberOfInputChannels,
            TaskActions taskActions,
            TaskIOMetricGroup metrics) {

        super(owningTaskName, jobId, consumedResultId, ResultPartitionType.PIPELINED,
                consumedSubpartitionIndex,
                numberOfInputChannels, taskActions, metrics);

        this.channelsWithData = lookupChannelQueue(this);
        this.uniquenessChecker = new HashSet<>();
    }

    /**
     * Reads the {@code inputChannelsWithData} field of the given gate via reflection.
     * Any failure is rethrown wrapped in a {@link RuntimeException}.
     */
    @SuppressWarnings("unchecked")
    private static ArrayDeque<InputChannel> lookupChannelQueue(SingleInputGate gate) {
        try {
            Field queueField = SingleInputGate.class.getDeclaredField("inputChannelsWithData");
            queueField.setAccessible(true);
            return (ArrayDeque<InputChannel>) queueField.get(gate);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException {
        // NOTE(review): presumably the same monitor the parent gate uses to guard the queue - confirm
        synchronized (channelsWithData) {
            final boolean withinBounds = channelsWithData.size() <= getNumberOfInputChannels();
            assertTrue("too many input channels", withinBounds);
            ensureUnique(channelsWithData);
        }

        return super.getNextBufferOrEvent();
    }

    /**
     * Fails the test if the given collection contains the same channel more than once.
     */
    private void ensureUnique(Collection<InputChannel> channels) {
        final HashSet<InputChannel> seen = this.uniquenessChecker;

        for (InputChannel candidate : channels) {
            final boolean firstOccurrence = seen.add(candidate);
            if (!firstOccurrence) {
                fail("Duplicate channel in input gate: " + candidate);
            }
        }

        final boolean sizesMatch = seen.size() == channels.size();
        assertTrue("found duplicate input channels", sizesMatch);

        // leave the scratch set empty for the next invocation
        seen.clear();
    }
}
}