/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.io.network.netty; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.NettyTestUtil.NettyServerAndClient; import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException; import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.ResultPartitionProvider; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.io.IOException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import static org.apache.flink.runtime.io.network.netty.NettyTestUtil.connect; import static org.apache.flink.runtime.io.network.netty.NettyTestUtil.createConfig; import static org.apache.flink.runtime.io.network.netty.NettyTestUtil.initServerAndClient; import static org.apache.flink.runtime.io.network.netty.NettyTestUtil.shutdown; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class ClientTransportErrorHandlingTest { /** * Verifies that failed client requests via {@link PartitionRequestClient} are correctly * attributed to the respective {@link RemoteInputChannel}. */ @Test public void testExceptionOnWrite() throws Exception { NettyProtocol protocol = new NettyProtocol() { @Override public ChannelHandler[] getServerChannelHandlers() { return new ChannelHandler[0]; } @Override public ChannelHandler[] getClientChannelHandlers() { return new PartitionRequestProtocol( mock(ResultPartitionProvider.class), mock(TaskEventDispatcher.class)).getClientChannelHandlers(); } }; // We need a real server and client in this test, because Netty's EmbeddedChannel is // not failing the ChannelPromise of failed writes. NettyServerAndClient serverAndClient = initServerAndClient(protocol, createConfig()); Channel ch = connect(serverAndClient); PartitionRequestClientHandler handler = getClientHandler(ch); // Last outbound handler throws Exception after 1st write ch.pipeline().addFirst(new ChannelOutboundHandlerAdapter() { int writeNum = 0; @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (writeNum >= 1) { throw new RuntimeException("Expected test exception."); } writeNum++; ctx.write(msg, promise); } }); PartitionRequestClient requestClient = new PartitionRequestClient( ch, handler, mock(ConnectionID.class), mock(PartitionRequestClientFactory.class)); // Create input channels RemoteInputChannel[] rich = new RemoteInputChannel[] { createRemoteInputChannel(), createRemoteInputChannel()}; final CountDownLatch sync = new CountDownLatch(1); // Do this with explicit synchronization. Otherwise this is not robust against slow timings // of the callback (e.g. we cannot just verify that it was called once, because there is // a chance that we do this too early). doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { sync.countDown(); return null; } }).when(rich[1]).onError(isA(LocalTransportException.class)); // First request is successful ChannelFuture f = requestClient.requestSubpartition(new ResultPartitionID(), 0, rich[0], 0); assertTrue(f.await().isSuccess()); // Second request is *not* successful f = requestClient.requestSubpartition(new ResultPartitionID(), 0, rich[1], 0); assertFalse(f.await().isSuccess()); // Only the second channel should be notified about the error verify(rich[0], times(0)).onError(any(LocalTransportException.class)); // Wait for the notification if (!sync.await(TestingUtils.TESTING_DURATION().toMillis(), TimeUnit.MILLISECONDS)) { fail("Timed out after waiting for " + TestingUtils.TESTING_DURATION().toMillis() + " ms to be notified about the channel error."); } shutdown(serverAndClient); } /** * Verifies that {@link NettyMessage.ErrorResponse} messages are correctly wrapped in * {@link RemoteTransportException} instances. */ @Test public void testWrappingOfRemoteErrorMessage() throws Exception { EmbeddedChannel ch = createEmbeddedChannel(); PartitionRequestClientHandler handler = getClientHandler(ch); // Create input channels RemoteInputChannel[] rich = new RemoteInputChannel[] { createRemoteInputChannel(), createRemoteInputChannel()}; for (RemoteInputChannel r : rich) { when(r.getInputChannelId()).thenReturn(new InputChannelID()); handler.addInputChannel(r); } // Error msg for channel[0] ch.pipeline().fireChannelRead(new NettyMessage.ErrorResponse( new RuntimeException("Expected test exception"), rich[0].getInputChannelId())); try { // Exception should not reach end of pipeline... ch.checkException(); } catch (Exception e) { fail("The exception reached the end of the pipeline and " + "was not handled correctly by the last handler."); } verify(rich[0], times(1)).onError(isA(RemoteTransportException.class)); verify(rich[1], never()).onError(any(Throwable.class)); // Fatal error for all channels ch.pipeline().fireChannelRead(new NettyMessage.ErrorResponse( new RuntimeException("Expected test exception"))); try { // Exception should not reach end of pipeline... ch.checkException(); } catch (Exception e) { fail("The exception reached the end of the pipeline and " + "was not handled correctly by the last handler."); } verify(rich[0], times(2)).onError(isA(RemoteTransportException.class)); verify(rich[1], times(1)).onError(isA(RemoteTransportException.class)); } /** * Verifies that unexpected remote closes are reported as an instance of * {@link RemoteTransportException}. */ @Test public void testExceptionOnRemoteClose() throws Exception { NettyProtocol protocol = new NettyProtocol() { @Override public ChannelHandler[] getServerChannelHandlers() { return new ChannelHandler[] { // Close on read new ChannelInboundHandlerAdapter() { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { ctx.channel().close(); } } }; } @Override public ChannelHandler[] getClientChannelHandlers() { return new PartitionRequestProtocol( mock(ResultPartitionProvider.class), mock(TaskEventDispatcher.class)).getClientChannelHandlers(); } }; NettyServerAndClient serverAndClient = initServerAndClient(protocol, createConfig()); Channel ch = connect(serverAndClient); PartitionRequestClientHandler handler = getClientHandler(ch); // Create input channels RemoteInputChannel[] rich = new RemoteInputChannel[] { createRemoteInputChannel(), createRemoteInputChannel()}; final CountDownLatch sync = new CountDownLatch(rich.length); Answer<Void> countDownLatch = new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { sync.countDown(); return null; } }; for (RemoteInputChannel r : rich) { doAnswer(countDownLatch).when(r).onError(any(Throwable.class)); handler.addInputChannel(r); } // Write something to trigger close by server ch.writeAndFlush(Unpooled.buffer().writerIndex(16)); // Wait for the notification if (!sync.await(TestingUtils.TESTING_DURATION().toMillis(), TimeUnit.MILLISECONDS)) { fail("Timed out after waiting for " + TestingUtils.TESTING_DURATION().toMillis() + " ms to be notified about remote connection close."); } // All the registered channels should be notified. for (RemoteInputChannel r : rich) { verify(r).onError(isA(RemoteTransportException.class)); } shutdown(serverAndClient); } /** * Verifies that fired Exceptions are handled correctly by the pipeline. */ @Test public void testExceptionCaught() throws Exception { EmbeddedChannel ch = createEmbeddedChannel(); PartitionRequestClientHandler handler = getClientHandler(ch); // Create input channels RemoteInputChannel[] rich = new RemoteInputChannel[] { createRemoteInputChannel(), createRemoteInputChannel()}; for (RemoteInputChannel r : rich) { when(r.getInputChannelId()).thenReturn(new InputChannelID()); handler.addInputChannel(r); } ch.pipeline().fireExceptionCaught(new Exception()); try { // Exception should not reach end of pipeline... ch.checkException(); } catch (Exception e) { fail("The exception reached the end of the pipeline and " + "was not handled correctly by the last handler."); } // ...but all the registered channels should be notified. for (RemoteInputChannel r : rich) { verify(r).onError(isA(LocalTransportException.class)); } } /** * Verifies that "Connection reset by peer" Exceptions are special-cased and are reported as * an instance of {@link RemoteTransportException}. */ @Test public void testConnectionResetByPeer() throws Throwable { EmbeddedChannel ch = createEmbeddedChannel(); PartitionRequestClientHandler handler = getClientHandler(ch); RemoteInputChannel rich = addInputChannel(handler); final Throwable[] error = new Throwable[1]; // Verify the Exception doAnswer(new Answer<Void>() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { Throwable cause = (Throwable) invocation.getArguments()[0]; try { assertEquals(RemoteTransportException.class, cause.getClass()); assertNotEquals("Connection reset by peer", cause.getMessage()); assertEquals(IOException.class, cause.getCause().getClass()); assertEquals("Connection reset by peer", cause.getCause().getMessage()); } catch (Throwable t) { error[0] = t; } return null; } }).when(rich).onError(any(Throwable.class)); ch.pipeline().fireExceptionCaught(new IOException("Connection reset by peer")); assertNull(error[0]); } /** * Verifies that the channel is closed if there is an error *during* error notification. */ @Test public void testChannelClosedOnExceptionDuringErrorNotification() throws Exception { EmbeddedChannel ch = createEmbeddedChannel(); PartitionRequestClientHandler handler = getClientHandler(ch); RemoteInputChannel rich = addInputChannel(handler); doThrow(new RuntimeException("Expected test exception")) .when(rich).onError(any(Throwable.class)); ch.pipeline().fireExceptionCaught(new Exception()); assertFalse(ch.isActive()); } // --------------------------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------------------------- private EmbeddedChannel createEmbeddedChannel() { PartitionRequestProtocol protocol = new PartitionRequestProtocol( mock(ResultPartitionProvider.class), mock(TaskEventDispatcher.class)); return new EmbeddedChannel(protocol.getClientChannelHandlers()); } private RemoteInputChannel addInputChannel(PartitionRequestClientHandler clientHandler) throws IOException { RemoteInputChannel rich = createRemoteInputChannel(); clientHandler.addInputChannel(rich); return rich; } private PartitionRequestClientHandler getClientHandler(Channel ch) { return ch.pipeline().get(PartitionRequestClientHandler.class); } private RemoteInputChannel createRemoteInputChannel() { return when(mock(RemoteInputChannel.class) .getInputChannelId()) .thenReturn(new InputChannelID()).getMock(); } }