/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.network;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemoryType;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.ResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.taskmanager.Task;
import org.apache.flink.runtime.taskmanager.TaskActions;

import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Various tests for the {@link NetworkEnvironment} class.
*/
public class NetworkEnvironmentTest {

	/** Total number of network buffers available to the {@link NetworkBufferPool} under test. */
	private static final int NUM_BUFFERS = 1024;

	/** Size (in bytes) of each memory segment in the {@link NetworkBufferPool} under test. */
	private static final int MEMORY_SEGMENT_SIZE = 128;

	/**
	 * Buffers per channel for bounded pools; passed to the {@link NetworkEnvironment} constructor.
	 * The expected maximum pool size used by this test is {@code channels * BUFFERS_PER_CHANNEL +
	 * EXTRA_BUFFERS_PER_GATE}.
	 */
	private static final int BUFFERS_PER_CHANNEL = 2;

	/**
	 * Extra buffers added on top of the per-channel buffers for bounded pools; passed to the
	 * {@link NetworkEnvironment} constructor.
	 */
	private static final int EXTRA_BUFFERS_PER_GATE = 8;

	/**
	 * Verifies that {@link NetworkEnvironment#registerTask(Task)} sets up (un)bounded buffer pool
	 * instances for various types of input and output channels.
	 */
	@Test
	public void testRegisterTaskUsesBoundedBuffers() throws Exception {
		final NetworkEnvironment network = new NetworkEnvironment(
			new NetworkBufferPool(NUM_BUFFERS, MEMORY_SEGMENT_SIZE, MemoryType.HEAP),
			new LocalConnectionManager(),
			new ResultPartitionManager(),
			new TaskEventDispatcher(),
			new KvStateRegistry(),
			null,
			IOManager.IOMode.SYNC,
			0,
			0,
			BUFFERS_PER_CHANNEL,
			EXTRA_BUFFERS_PER_GATE);

		try {
			// result partitions
			ResultPartition rp1 = createResultPartition(ResultPartitionType.PIPELINED, 2);
			ResultPartition rp2 = createResultPartition(ResultPartitionType.BLOCKING, 2);
			ResultPartition rp3 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 2);
			ResultPartition rp4 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 8);
			final ResultPartition[] resultPartitions = new ResultPartition[] {rp1, rp2, rp3, rp4};
			final ResultPartitionWriter[] resultPartitionWriters = new ResultPartitionWriter[] {
				new ResultPartitionWriter(rp1), new ResultPartitionWriter(rp2),
				new ResultPartitionWriter(rp3), new ResultPartitionWriter(rp4)};

			// input gates (their buffer pool sizes are checked inside the setBufferPool() answer)
			final SingleInputGate[] inputGates = new SingleInputGate[] {
				createSingleInputGateMock(ResultPartitionType.PIPELINED, 2),
				createSingleInputGateMock(ResultPartitionType.BLOCKING, 2),
				createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 2),
				createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 8)};

			// overall task to register
			Task task = mock(Task.class);
			when(task.getProducedPartitions()).thenReturn(resultPartitions);
			when(task.getAllWriters()).thenReturn(resultPartitionWriters);
			when(task.getAllInputGates()).thenReturn(inputGates);

			network.registerTask(task);

			assertEquals(Integer.MAX_VALUE, rp1.getBufferPool().getMaxNumberOfMemorySegments());
			assertEquals(Integer.MAX_VALUE, rp2.getBufferPool().getMaxNumberOfMemorySegments());
			assertEquals(2 * BUFFERS_PER_CHANNEL + EXTRA_BUFFERS_PER_GATE,
				rp3.getBufferPool().getMaxNumberOfMemorySegments());
			assertEquals(8 * BUFFERS_PER_CHANNEL + EXTRA_BUFFERS_PER_GATE,
				rp4.getBufferPool().getMaxNumberOfMemorySegments());

			// The per-gate assertions run inside the mocked setBufferPool() answer; make sure that
			// answer was actually invoked, otherwise those assertions would pass vacuously.
			for (SingleInputGate ig : inputGates) {
				verify(ig).setBufferPool(any(BufferPool.class));
			}
		} finally {
			// release network resources even if one of the assertions above fails
			network.shutdown();
		}
	}

	/**
	 * Helper to create simple {@link ResultPartition} instance for use by a {@link Task} inside
	 * {@link NetworkEnvironment#registerTask(Task)}.
	 *
	 * @param partitionType
	 * 		the produced partition type
	 * @param channels
	 * 		the number of output channels (also used as the number of subpartitions)
	 *
	 * @return instance with minimal data set and some mocks so that it is useful for {@link
	 * NetworkEnvironment#registerTask(Task)}
	 */
	private static ResultPartition createResultPartition(
			final ResultPartitionType partitionType, final int channels) {
		return new ResultPartition(
			"TestTask-" + partitionType + ":" + channels,
			mock(TaskActions.class),
			new JobID(),
			new ResultPartitionID(),
			partitionType,
			channels,
			channels,
			mock(ResultPartitionManager.class),
			mock(ResultPartitionConsumableNotifier.class),
			mock(IOManager.class),
			false);
	}

	/**
	 * Helper to create a mock of a {@link SingleInputGate} for use by a {@link Task} inside
	 * {@link NetworkEnvironment#registerTask(Task)}.
	 *
	 * <p>The mock asserts, when {@link SingleInputGate#setBufferPool(BufferPool)} is called, that
	 * the given pool is bounded to {@code channels * BUFFERS_PER_CHANNEL + EXTRA_BUFFERS_PER_GATE}
	 * segments for {@link ResultPartitionType#PIPELINED_BOUNDED} and unbounded otherwise.
	 *
	 * @param partitionType
	 * 		the consumed partition type
	 * @param channels
	 * 		the number of input channels
	 *
	 * @return mock with minimal functionality necessary by {@link NetworkEnvironment#registerTask(Task)}
	 */
	private static SingleInputGate createSingleInputGateMock(
			final ResultPartitionType partitionType, final int channels) {
		SingleInputGate ig = mock(SingleInputGate.class);
		when(ig.getConsumedPartitionType()).thenReturn(partitionType);
		when(ig.getNumberOfInputChannels()).thenReturn(channels);
		doAnswer(new Answer<Void>() {
			@Override
			public Void answer(final InvocationOnMock invocation) throws Throwable {
				BufferPool bp = invocation.getArgumentAt(0, BufferPool.class);
				if (partitionType == ResultPartitionType.PIPELINED_BOUNDED) {
					assertEquals(channels * BUFFERS_PER_CHANNEL + EXTRA_BUFFERS_PER_GATE,
						bp.getMaxNumberOfMemorySegments());
				} else {
					assertEquals(Integer.MAX_VALUE, bp.getMaxNumberOfMemorySegments());
				}
				return null;
			}
		}).when(ig).setBufferPool(any(BufferPool.class));
		return ig;
	}
}