/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.network.netty;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import org.apache.flink.core.memory.HeapMemorySegment;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.netty.NettyMessage.BufferResponse;
import org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.runtime.testutils.DiscardingRecycler;
import org.apache.flink.runtime.util.event.EventListener;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link PartitionRequestClientHandler}.
 *
 * <p>The tests drive the handler either directly via {@code channelRead(...)} with a mocked
 * {@link ChannelHandlerContext}, or through an {@link EmbeddedChannel}. Input channels and buffer
 * providers are Mockito mocks.
 */
public class PartitionRequestClientHandlerTest {

    /**
     * Tests a fix for FLINK-1627.
     *
     * <p>FLINK-1627 discovered a race condition, which could lead to an infinite loop when a
     * receiver was cancelled during a certain time of decoding a message. The test reproduces the
     * input, which led to the infinite loop: when the handler gets a reference to the buffer
     * provider of the receiving input channel, but the respective input channel is released (and
     * the corresponding buffer provider destroyed), the handler did not notice this.
     *
     * <p>The test has no explicit assertion; it passes if {@code channelRead} returns before the
     * timeout expires.
     *
     * @see <a href="https://issues.apache.org/jira/browse/FLINK-1627">FLINK-1627</a>
     */
    @Test(timeout = 60000)
    @SuppressWarnings("unchecked")
    public void testReleaseInputChannelDuringDecode() throws Exception {
        // Mocks an input channel in a state as it was released during a decode.
        final BufferProvider bufferProvider = mock(BufferProvider.class);
        when(bufferProvider.requestBuffer()).thenReturn(null);
        when(bufferProvider.isDestroyed()).thenReturn(true);
        when(bufferProvider.addListener(any(EventListener.class))).thenReturn(false);

        final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
        when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
        when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);

        final BufferResponse receivedBuffer = createBufferResponse(
                TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId());

        final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
        client.addInputChannel(inputChannel);

        client.channelRead(mock(ChannelHandlerContext.class), receivedBuffer);
    }

    /**
     * Tests a fix for FLINK-1761.
     *
     * <p>FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0.
     */
    @Test
    public void testReceiveEmptyBuffer() throws Exception {
        // Minimal mock of a remote input channel
        final BufferProvider bufferProvider = mock(BufferProvider.class);
        when(bufferProvider.requestBuffer()).thenReturn(TestBufferFactory.createBuffer());

        final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
        when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
        when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);

        // An empty buffer of size 0
        final Buffer emptyBuffer = TestBufferFactory.createBuffer();
        emptyBuffer.setSize(0);

        final BufferResponse receivedBuffer = createBufferResponse(
                emptyBuffer, 0, inputChannel.getInputChannelId());

        final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
        client.addInputChannel(inputChannel);

        // Read the empty buffer
        client.channelRead(mock(ChannelHandlerContext.class), receivedBuffer);

        // This should not throw an exception
        verify(inputChannel, never()).onError(any(Throwable.class));
    }

    /**
     * Verifies that {@link RemoteInputChannel#onFailedPartitionRequest()} is called when a
     * {@link PartitionNotFoundException} is received.
     */
    @Test
    public void testReceivePartitionNotFoundException() throws Exception {
        // Minimal mock of a remote input channel
        final BufferProvider bufferProvider = mock(BufferProvider.class);
        when(bufferProvider.requestBuffer()).thenReturn(TestBufferFactory.createBuffer());

        final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
        when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());
        when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);

        final ErrorResponse partitionNotFound = new ErrorResponse(
                new PartitionNotFoundException(new ResultPartitionID()),
                inputChannel.getInputChannelId());

        final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
        client.addInputChannel(inputChannel);

        // Mock channel context
        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
        when(ctx.channel()).thenReturn(mock(Channel.class));

        client.channelActive(ctx);

        client.channelRead(ctx, partitionNotFound);

        verify(inputChannel, times(1)).onFailedPartitionRequest();
    }

    /**
     * Verifies that {@code cancelRequestFor} can be called before the channel became active
     * without throwing a {@link NullPointerException}, both for a {@code null} channel ID and
     * for the ID of a registered input channel.
     */
    @Test
    public void testCancelBeforeActive() throws Exception {
        final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
        when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID());

        final PartitionRequestClientHandler client = new PartitionRequestClientHandler();
        client.addInputChannel(inputChannel);

        // Don't throw NPE
        client.cancelRequestFor(null);

        // Don't throw NPE, because channel is not active yet
        client.cancelRequestFor(inputChannel.getInputChannelId());
    }

    /**
     * Tests that an unsuccessful message decode call for a staged message
     * does not leave the channel with auto read set to false.
     */
    @Test
    @SuppressWarnings("unchecked")
    public void testAutoReadAfterUnsuccessfulStagedMessage() throws Exception {
        PartitionRequestClientHandler handler = new PartitionRequestClientHandler();
        EmbeddedChannel channel = new EmbeddedChannel(handler);

        // Captures the listener that is registered at the buffer provider, so that the test can
        // later notify it about an available buffer.
        final AtomicReference<EventListener<Buffer>> listener = new AtomicReference<>();

        BufferProvider bufferProvider = mock(BufferProvider.class);
        when(bufferProvider.addListener(any(EventListener.class))).thenAnswer(new Answer<Boolean>() {
            @Override
            public Boolean answer(InvocationOnMock invocation) throws Throwable {
                listener.set((EventListener<Buffer>) invocation.getArguments()[0]);
                return true;
            }
        });

        when(bufferProvider.requestBuffer()).thenReturn(null);

        InputChannelID channelId = new InputChannelID(0, 0);
        RemoteInputChannel inputChannel = mock(RemoteInputChannel.class);
        when(inputChannel.getInputChannelId()).thenReturn(channelId);

        // The 3rd staged msg has a null buffer provider
        when(inputChannel.getBufferProvider()).thenReturn(bufferProvider, bufferProvider, null);

        handler.addInputChannel(inputChannel);

        BufferResponse msg = createBufferResponse(createBuffer(true), 0, channelId);

        // Write 1st buffer msg. No buffer is available, therefore the buffer
        // should be staged and auto read should be set to false.
        assertTrue(channel.config().isAutoRead());
        channel.writeInbound(msg);

        // No buffer available, auto read false
        assertFalse(channel.config().isAutoRead());

        // Write more buffers... all staged.
        msg = createBufferResponse(createBuffer(true), 1, channelId);
        channel.writeInbound(msg);

        msg = createBufferResponse(createBuffer(true), 2, channelId);
        channel.writeInbound(msg);

        // Notify about buffer => handle 1st msg
        Buffer availableBuffer = createBuffer(false);
        listener.get().onEvent(availableBuffer);

        // Start processing of staged buffers (in run pending tasks). Make
        // sure that the buffer provider acts like it's destroyed.
        when(bufferProvider.addListener(any(EventListener.class))).thenReturn(false);
        when(bufferProvider.isDestroyed()).thenReturn(true);

        // Execute all tasks that are scheduled in the event loop. Further
        // eventLoop().execute() calls are directly executed, if they are
        // called in the scope of this call.
        channel.runPendingTasks();

        assertTrue(channel.config().isAutoRead());
    }

    // ---------------------------------------------------------------------------------------------

    /**
     * Creates a buffer backed by a freshly allocated, unpooled heap memory segment of 1024 bytes
     * that is recycled via {@link DiscardingRecycler#INSTANCE}.
     *
     * @param fill if {@code true}, the byte at every index {@code i} of the segment is set to
     *             {@code (byte) i}; otherwise the segment is left as allocated
     * @return the newly created buffer
     */
    private static Buffer createBuffer(boolean fill) {
        final int size = 1024;

        MemorySegment segment = HeapMemorySegment.FACTORY.allocateUnpooledSegment(size, null);
        if (fill) {
            for (int i = 0; i < size; i++) {
                segment.put(i, (byte) i);
            }
        }

        return new Buffer(segment, DiscardingRecycler.INSTANCE, true);
    }

    /**
     * Returns a deserialized buffer message as it would be received during runtime.
     *
     * @param buffer the buffer to wrap into the response that gets serialized
     * @param sequenceNumber the sequence number of the response
     * @param receivingChannelId the ID of the input channel that the response is addressed to
     * @return the response obtained by serializing a {@link BufferResponse} and reading it back
     * @throws IOException declared because serializing the response may throw it
     */
    private static BufferResponse createBufferResponse(
            Buffer buffer,
            int sequenceNumber,
            InputChannelID receivingChannelId) throws IOException {

        // Mock buffer to serialize
        BufferResponse resp = new BufferResponse(buffer, sequenceNumber, receivingChannelId);

        ByteBuf serialized = resp.write(UnpooledByteBufAllocator.DEFAULT);

        // Skip general header bytes. skipBytes only advances the reader index, whereas
        // readBytes(int) would additionally allocate and fill a new buffer that nobody releases.
        serialized.skipBytes(NettyMessage.HEADER_LENGTH);

        BufferResponse deserialized = new BufferResponse();

        // Deserialize the bytes again. We have to go this way, because we only partly deserialize
        // the header of the response and wait for a buffer from the buffer pool to copy the payload
        // data into.
        deserialized.readFrom(serialized);

        return deserialized;
    }
}