/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.io.disk.iomanager; import org.apache.flink.core.memory.HeapMemorySegment; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.util.TestNotificationListener; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.modules.junit4.PowerMockRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; @RunWith(PowerMockRunner.class) public class AsynchronousFileIOChannelTest { private static final Logger LOG = LoggerFactory.getLogger(AsynchronousFileIOChannelTest.class); @Test public void testAllRequestsProcessedListenerNotification() throws Exception { // -- Config ---------------------------------------------------------- final int numberOfRuns = 10; final int numberOfRequests = 100; // -- Setup ----------------------------------------------------------- final IOManagerAsync ioManager = new IOManagerAsync(); final ExecutorService executor = Executors.newFixedThreadPool(3); final Random random = new Random(); final RequestQueue<WriteRequest> requestQueue = new RequestQueue<WriteRequest>(); final RequestDoneCallback<Buffer> ioChannelCallback = mock(RequestDoneCallback.class); final TestNotificationListener listener = new TestNotificationListener(); // -- The Test -------------------------------------------------------- try { // Repeatedly add requests and process them and have one thread try to register as a // listener until the channel is closed and all requests are processed. for (int run = 0; run < numberOfRuns; run++) { final TestAsyncFileIOChannel ioChannel = new TestAsyncFileIOChannel( ioManager.createChannel(), requestQueue, ioChannelCallback, true); final CountDownLatch sync = new CountDownLatch(3); // The mock requests final Buffer buffer = mock(Buffer.class); final WriteRequest request = mock(WriteRequest.class); // Add requests task Callable<Void> addRequestsTask = new Callable<Void>() { @Override public Void call() throws Exception { for (int i = 0; i < numberOfRuns; i++) { LOG.debug("Starting run {}.", i + 1); for (int j = 0; j < numberOfRequests; j++) { ioChannel.addRequest(request); } LOG.debug("Added all ({}) requests of run {}.", numberOfRequests, i + 1); int sleep = random.nextInt(10); LOG.debug("Sleeping for {} ms before next run.", sleep); Thread.sleep(sleep); } LOG.debug("Done. Closing channel."); ioChannel.close(); sync.countDown(); return null; } }; // Process requests task Callable<Void> processRequestsTask = new Callable<Void>() { @Override public Void call() throws Exception { int total = numberOfRequests * numberOfRuns; for (int i = 0; i < total; i++) { requestQueue.take(); ioChannel.handleProcessedBuffer(buffer, null); } LOG.debug("Processed all ({}) requests.", numberOfRequests); sync.countDown(); return null; } }; // Listener Callable<Void> registerListenerTask = new Callable<Void>() { @Override public Void call() throws Exception { while (true) { int current = listener.getNumberOfNotifications(); if (ioChannel.registerAllRequestsProcessedListener(listener)) { listener.waitForNotification(current); } else if (ioChannel.isClosed()) { break; } } LOG.debug("Stopping listener. Channel closed."); sync.countDown(); return null; } }; // Run tasks in random order final List<Callable<?>> tasks = new LinkedList<Callable<?>>(); tasks.add(addRequestsTask); tasks.add(processRequestsTask); tasks.add(registerListenerTask); Collections.shuffle(tasks); for (Callable<?> task : tasks) { executor.submit(task); } if (!sync.await(2, TimeUnit.MINUTES)) { fail("Test failed due to a timeout. This indicates a deadlock due to the way" + "that listeners are registered/notified in the asynchronous file I/O" + "channel."); } listener.reset(); } } finally { ioManager.shutdown(); executor.shutdown(); } } @Test public void testClosedButAddRequestAndRegisterListenerRace() throws Exception { // -- Config ---------------------------------------------------------- final int numberOfRuns = 1024; // -- Setup ----------------------------------------------------------- final IOManagerAsync ioManager = new IOManagerAsync(); final ExecutorService executor = Executors.newFixedThreadPool(2); final RequestQueue<WriteRequest> requestQueue = new RequestQueue<WriteRequest>(); @SuppressWarnings("unchecked") final RequestDoneCallback<Buffer> ioChannelCallback = mock(RequestDoneCallback.class); final TestNotificationListener listener = new TestNotificationListener(); // -- The Test -------------------------------------------------------- try { // Repeatedly close the channel and add a request. for (int i = 0; i < numberOfRuns; i++) { final TestAsyncFileIOChannel ioChannel = new TestAsyncFileIOChannel( ioManager.createChannel(), requestQueue, ioChannelCallback, true); final CountDownLatch sync = new CountDownLatch(2); final WriteRequest request = mock(WriteRequest.class); ioChannel.close(); // Add request task Callable<Void> addRequestTask = new Callable<Void>() { @Override public Void call() throws Exception { try { ioChannel.addRequest(request); } catch (Throwable expected) { } finally { sync.countDown(); } return null; } }; // Listener Callable<Void> registerListenerTask = new Callable<Void>() { @Override public Void call() throws Exception { try { while (true) { int current = listener.getNumberOfNotifications(); if (ioChannel.registerAllRequestsProcessedListener(listener)) { listener.waitForNotification(current); } else if (ioChannel.isClosed()) { break; } } } finally { sync.countDown(); } return null; } }; executor.submit(addRequestTask); executor.submit(registerListenerTask); if (!sync.await(2, TimeUnit.MINUTES)) { fail("Test failed due to a timeout. This indicates a deadlock due to the way" + "that listeners are registered/notified in the asynchronous file I/O" + "channel."); } } } finally { ioManager.shutdown(); executor.shutdown(); } } @Test public void testClosingWaits() { IOManagerAsync ioMan = new IOManagerAsync(); try { final int NUM_BLOCKS = 100; final MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(32 * 1024); final AtomicInteger callbackCounter = new AtomicInteger(); final AtomicBoolean exceptionOccurred = new AtomicBoolean(); final RequestDoneCallback<MemorySegment> callback = new RequestDoneCallback<MemorySegment>() { @Override public void requestSuccessful(MemorySegment buffer) { // we do the non safe variant. the callbacks should come in order from // the same thread, so it should always work callbackCounter.set(callbackCounter.get() + 1); if (buffer != seg) { exceptionOccurred.set(true); } } @Override public void requestFailed(MemorySegment buffer, IOException e) { exceptionOccurred.set(true); } }; BlockChannelWriterWithCallback<MemorySegment> writer = ioMan.createBlockChannelWriter(ioMan.createChannel(), callback); try { for (int i = 0; i < NUM_BLOCKS; i++) { writer.writeBlock(seg); } writer.close(); assertEquals(NUM_BLOCKS, callbackCounter.get()); assertFalse(exceptionOccurred.get()); } finally { writer.closeAndDelete(); } } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } finally { ioMan.shutdown(); } } @Test public void testExceptionForwardsToClose() { IOManagerAsync ioMan = new IOManagerAsync(); try { testExceptionForwardsToClose(ioMan, 100, 1); testExceptionForwardsToClose(ioMan, 100, 50); testExceptionForwardsToClose(ioMan, 100, 100); } finally { ioMan.shutdown(); } } private void testExceptionForwardsToClose(IOManagerAsync ioMan, final int numBlocks, final int failingBlock) { try { MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(32 * 1024); FileIOChannel.ID channelId = ioMan.createChannel(); BlockChannelWriterWithCallback<MemorySegment> writer = new AsynchronousBlockWriterWithCallback(channelId, ioMan.getWriteRequestQueue(channelId), new NoOpCallback()) { private int numBlocks; @Override public void writeBlock(MemorySegment segment) throws IOException { numBlocks++; if (numBlocks == failingBlock) { this.requestsNotReturned.incrementAndGet(); this.requestQueue.add(new FailingWriteRequest(this, segment)); } else { super.writeBlock(segment); } } }; try { for (int i = 0; i < numBlocks; i++) { writer.writeBlock(seg); } writer.close(); fail("did not forward exception"); } catch (IOException e) { // expected } finally { try { writer.closeAndDelete(); } catch (Throwable ignored) {} } } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } private static class NoOpCallback implements RequestDoneCallback<MemorySegment> { @Override public void requestSuccessful(MemorySegment buffer) {} @Override public void requestFailed(MemorySegment buffer, IOException e) {} } private static class FailingWriteRequest implements WriteRequest { private final AsynchronousFileIOChannel<MemorySegment, WriteRequest> channel; private final MemorySegment segment; protected FailingWriteRequest(AsynchronousFileIOChannel<MemorySegment, WriteRequest> targetChannel, MemorySegment segment) { this.channel = targetChannel; this.segment = segment; } @Override public void write() throws IOException { throw new IOException(); } @Override public void requestDone(IOException ioex) { this.channel.handleProcessedBuffer(this.segment, ioex); } } private static class TestAsyncFileIOChannel extends AsynchronousFileIOChannel<Buffer, WriteRequest> { protected TestAsyncFileIOChannel( ID channelID, RequestQueue<WriteRequest> requestQueue, RequestDoneCallback<Buffer> callback, boolean writeEnabled) throws IOException { super(channelID, requestQueue, callback, writeEnabled); } int getNumberOfOutstandingRequests() { return requestsNotReturned.get(); } } }