/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.io.network.api.writer;

import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.core.memory.MemoryType;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.SerializationResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.BufferProvider;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.util.TestBufferFactory;
import org.apache.flink.runtime.io.network.util.TestInfiniteBufferProvider;
import org.apache.flink.runtime.io.network.util.TestTaskEvent;
import org.apache.flink.runtime.testutils.DiscardingRecycler;
import org.apache.flink.types.IntValue;
import org.apache.flink.util.XORShiftRandom;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Tests for the {@link RecordWriter}, covering resource release on
 * interrupts/exceptions and event broadcasting behavior.
 */
@PrepareForTest({ResultPartitionWriter.class, EventSerializer.class})
@RunWith(PowerMockRunner.class)
public class RecordWriterTest {

	// ---------------------------------------------------------------------------------------------
	// Resource release tests
	// ---------------------------------------------------------------------------------------------

	/**
	 * Tests a fix for FLINK-2089.
	 *
	 * @see <a href="https://issues.apache.org/jira/browse/FLINK-2089">FLINK-2089</a>
	 */
	@Test
	public void testClearBuffersAfterInterruptDuringBlockingBufferRequest() throws Exception {
		ExecutorService executor = null;

		try {
			executor = Executors.newSingleThreadExecutor();

			final CountDownLatch sync = new CountDownLatch(2);

			final Buffer buffer = spy(TestBufferFactory.createBuffer(4));

			// Return buffer for first request, but block for all following requests.
			Answer<Buffer> request = new Answer<Buffer>() {
				@Override
				public Buffer answer(InvocationOnMock invocation) throws Throwable {
					sync.countDown();

					if (sync.getCount() == 1) {
						return buffer;
					}

					// Block forever; the test interrupts this thread via Future#cancel(true).
					final Object o = new Object();
					synchronized (o) {
						while (true) {
							o.wait();
						}
					}
				}
			};

			BufferProvider bufferProvider = mock(BufferProvider.class);
			when(bufferProvider.requestBufferBlocking()).thenAnswer(request);

			ResultPartitionWriter partitionWriter = createResultPartitionWriter(bufferProvider);

			final RecordWriter<IntValue> recordWriter = new RecordWriter<IntValue>(partitionWriter);

			Future<?> result = executor.submit(new Callable<Void>() {
				@Override
				public Void call() throws Exception {
					IntValue val = new IntValue(0);

					try {
						recordWriter.emit(val);
						recordWriter.flush();

						recordWriter.emit(val);
					}
					catch (InterruptedException e) {
						recordWriter.clearBuffers();
						// Restore the interrupt status instead of swallowing it.
						Thread.currentThread().interrupt();
					}

					return null;
				}
			});

			sync.await();

			// Interrupt the Thread.
			//
			// The second emit call requests a new buffer and blocks the thread.
			// When interrupting the thread at this point, clearing the buffers
			// should not recycle any buffer.
			result.cancel(true);

			recordWriter.clearBuffers();

			// Verify that buffers have been requested, but only one has been written out.
			verify(bufferProvider, times(2)).requestBufferBlocking();
			verify(partitionWriter, times(1)).writeBuffer(any(Buffer.class), anyInt());

			// Verify that the written out buffer has only been recycled once
			// (by the partition writer).
			assertTrue("Buffer not recycled.", buffer.isRecycled());
			verify(buffer, times(1)).recycle();
		}
		finally {
			if (executor != null) {
				executor.shutdown();
			}
		}
	}

	/**
	 * Tests that the record writer clears its internal buffer state when the
	 * partition writer fails while writing a buffer out, for all emit paths
	 * (emit, flush, broadcast emit, super-step event, task event).
	 */
	@Test
	public void testClearBuffersAfterExceptionInPartitionWriter() throws Exception {
		NetworkBufferPool buffers = null;
		BufferPool bufferPool = null;

		try {
			buffers = new NetworkBufferPool(1, 1024, MemoryType.HEAP);
			bufferPool = spy(buffers.createBufferPool(1, Integer.MAX_VALUE));

			ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class);
			when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferPool));
			when(partitionWriter.getNumberOfOutputChannels()).thenReturn(1);

			// Recycle buffer and throw Exception
			doAnswer(new Answer<Void>() {
				@Override
				public Void answer(InvocationOnMock invocation) throws Throwable {
					Buffer buffer = (Buffer) invocation.getArguments()[0];
					buffer.recycle();

					throw new RuntimeException("Expected test Exception");
				}
			}).when(partitionWriter).writeBuffer(any(Buffer.class), anyInt());

			RecordWriter<IntValue> recordWriter = new RecordWriter<>(partitionWriter);

			try {
				// Verify that emit correctly clears the buffer. The infinite loop looks
				// dangerous indeed, but the buffer will only be flushed after it's full. Adding a
				// manual flush here doesn't test this case (see next).
				for (;;) {
					recordWriter.emit(new IntValue(0));
				}
			}
			catch (Exception e) {
				// Verify that the buffer is not part of the record writer state after a failure
				// to flush it out. If the buffer is still part of the record writer state, this
				// will fail, because the buffer has already been recycled. NOTE: The mock
				// partition writer needs to recycle the buffer to correctly test this.
				recordWriter.clearBuffers();
			}

			// Verify expected methods have been called
			verify(partitionWriter, times(1)).writeBuffer(any(Buffer.class), anyInt());
			verify(bufferPool, times(1)).requestBufferBlocking();

			try {
				// Verify that manual flushing correctly clears the buffer.
				recordWriter.emit(new IntValue(0));
				recordWriter.flush();

				Assert.fail("Did not throw expected test Exception");
			}
			catch (Exception e) {
				recordWriter.clearBuffers();
			}

			// Verify expected methods have been called
			verify(partitionWriter, times(2)).writeBuffer(any(Buffer.class), anyInt());
			verify(bufferPool, times(2)).requestBufferBlocking();

			try {
				// Verify that broadcast emit correctly clears the buffer.
				for (;;) {
					recordWriter.broadcastEmit(new IntValue(0));
				}
			}
			catch (Exception e) {
				recordWriter.clearBuffers();
			}

			// Verify expected methods have been called
			verify(partitionWriter, times(3)).writeBuffer(any(Buffer.class), anyInt());
			verify(bufferPool, times(3)).requestBufferBlocking();

			try {
				// Verify that end of super step correctly clears the buffer.
				recordWriter.emit(new IntValue(0));
				recordWriter.broadcastEvent(EndOfSuperstepEvent.INSTANCE);

				Assert.fail("Did not throw expected test Exception");
			}
			catch (Exception e) {
				recordWriter.clearBuffers();
			}

			// Verify expected methods have been called
			verify(partitionWriter, times(4)).writeBuffer(any(Buffer.class), anyInt());
			verify(bufferPool, times(4)).requestBufferBlocking();

			try {
				// Verify that broadcasting an event correctly clears the buffer.
				recordWriter.emit(new IntValue(0));
				recordWriter.broadcastEvent(new TestTaskEvent());

				Assert.fail("Did not throw expected test Exception");
			}
			catch (Exception e) {
				recordWriter.clearBuffers();
			}

			// Verify expected methods have been called
			verify(partitionWriter, times(5)).writeBuffer(any(Buffer.class), anyInt());
			verify(bufferPool, times(5)).requestBufferBlocking();
		}
		finally {
			if (bufferPool != null) {
				assertEquals(1, bufferPool.getNumberOfAvailableMemorySegments());
				bufferPool.lazyDestroy();
			}

			if (buffers != null) {
				assertEquals(1, buffers.getNumberOfAvailableMemorySegments());
				buffers.destroy();
			}
		}
	}

	/**
	 * Tests that {@link RecordWriter#clearBuffers()} also resets the
	 * serializer state, so that a subsequent flush does not fail.
	 */
	@Test
	public void testSerializerClearedAfterClearBuffers() throws Exception {
		final Buffer buffer = TestBufferFactory.createBuffer(16);

		ResultPartitionWriter partitionWriter = createResultPartitionWriter(
			createBufferProvider(buffer));

		RecordWriter<IntValue> recordWriter = new RecordWriter<IntValue>(partitionWriter);

		// Fill a buffer, but don't write it out.
		recordWriter.emit(new IntValue(0));
		verify(partitionWriter, never()).writeBuffer(any(Buffer.class), anyInt());

		// Clear all buffers.
		recordWriter.clearBuffers();

		// This should not throw an Exception iff the serializer state
		// has been cleared as expected.
		recordWriter.flush();
	}

	/**
	 * Tests broadcasting events when no records have been emitted yet.
	 */
	@Test
	public void testBroadcastEventNoRecords() throws Exception {
		int numChannels = 4;
		int bufferSize = 32;

		@SuppressWarnings("unchecked")
		Queue<BufferOrEvent>[] queues = new Queue[numChannels];
		for (int i = 0; i < numChannels; i++) {
			queues[i] = new ArrayDeque<>();
		}

		BufferProvider bufferProvider = createBufferProvider(bufferSize);

		ResultPartitionWriter partitionWriter = createCollectingPartitionWriter(queues, bufferProvider);
		RecordWriter<ByteArrayIO> writer = new RecordWriter<>(partitionWriter, new RoundRobin<ByteArrayIO>());
		CheckpointBarrier barrier = new CheckpointBarrier(Integer.MAX_VALUE + 919192L, Integer.MAX_VALUE + 18828228L, CheckpointOptions.forFullCheckpoint());

		// No records emitted yet, broadcast should not request a buffer
		writer.broadcastEvent(barrier);

		verify(bufferProvider, times(0)).requestBufferBlocking();

		for (Queue<BufferOrEvent> queue : queues) {
			assertEquals(1, queue.size());
			BufferOrEvent boe = queue.remove();
			assertTrue(boe.isEvent());
			assertEquals(barrier, boe.getEvent());
		}
	}

	/**
	 * Tests broadcasting events when records have been emitted. The emitted
	 * records cover all three {@link SerializationResult} types.
	 */
	@Test
	public void testBroadcastEventMixedRecords() throws Exception {
		Random rand = new XORShiftRandom();
		int numChannels = 4;
		int bufferSize = 32;
		int lenBytes = 4; // serialized length

		@SuppressWarnings("unchecked")
		Queue<BufferOrEvent>[] queues = new Queue[numChannels];
		for (int i = 0; i < numChannels; i++) {
			queues[i] = new ArrayDeque<>();
		}

		BufferProvider bufferProvider = createBufferProvider(bufferSize);

		ResultPartitionWriter partitionWriter = createCollectingPartitionWriter(queues, bufferProvider);
		RecordWriter<ByteArrayIO> writer = new RecordWriter<>(partitionWriter, new RoundRobin<ByteArrayIO>());
		CheckpointBarrier barrier = new CheckpointBarrier(Integer.MAX_VALUE + 1292L, Integer.MAX_VALUE + 199L, CheckpointOptions.forFullCheckpoint());

		// Emit records on some channels first (requesting buffers), then
		// broadcast the event. The record buffers should be emitted first, then
		// the event. After the event, no new buffer should be requested.

		// (i) Smaller than the buffer size (single buffer request => 1)
		byte[] bytes = new byte[bufferSize / 2];
		rand.nextBytes(bytes);

		writer.emit(new ByteArrayIO(bytes));

		// (ii) Larger than the buffer size (two buffer requests => 1 + 2)
		bytes = new byte[bufferSize + 1];
		rand.nextBytes(bytes);

		writer.emit(new ByteArrayIO(bytes));

		// (iii) Exactly the buffer size (single buffer request => 1 + 2 + 1)
		bytes = new byte[bufferSize - lenBytes];
		rand.nextBytes(bytes);

		writer.emit(new ByteArrayIO(bytes));

		// (iv) Nothing on the 4th channel (no buffer request => 1 + 2 + 1 + 0 = 4)

		// (v) Broadcast the event
		writer.broadcastEvent(barrier);

		verify(bufferProvider, times(4)).requestBufferBlocking();

		assertEquals(2, queues[0].size()); // 1 buffer + 1 event
		assertEquals(3, queues[1].size()); // 2 buffers + 1 event
		assertEquals(2, queues[2].size()); // 1 buffer + 1 event
		assertEquals(1, queues[3].size()); // 0 buffers + 1 event
	}

	/**
	 * Tests that event buffers are properly recycled when broadcasting events
	 * to multiple channels.
	 */
	@Test
	public void testBroadcastEventBufferReferenceCounting() throws Exception {
		Buffer buffer = EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE);

		// Partial mocking of static method...
		PowerMockito
			.stub(PowerMockito.method(EventSerializer.class, "toBuffer"))
			.toReturn(buffer);

		@SuppressWarnings("unchecked")
		ArrayDeque<BufferOrEvent>[] queues = new ArrayDeque[]{new ArrayDeque(), new ArrayDeque()};

		ResultPartitionWriter partition = createCollectingPartitionWriter(queues, new TestInfiniteBufferProvider());

		RecordWriter<?> writer = new RecordWriter<>(partition);

		writer.broadcastEvent(EndOfPartitionEvent.INSTANCE);

		// Verify added to all queues
		assertEquals(1, queues[0].size());
		assertEquals(1, queues[1].size());

		assertTrue(buffer.isRecycled());
	}

	// ---------------------------------------------------------------------------------------------
	// Helpers
	// ---------------------------------------------------------------------------------------------

	/**
	 * Creates a mock partition writer that collects the added buffers/events.
	 *
	 * <p>This much mocking should not be necessary with better designed
	 * interfaces. Refactoring this will take too much time now though, hence
	 * the mocking. Ideally, we will refactor all of this mess in order to make
	 * our lives easier and test it better.
	 */
	private ResultPartitionWriter createCollectingPartitionWriter(
			final Queue<BufferOrEvent>[] queues,
			BufferProvider bufferProvider) throws IOException {

		int numChannels = queues.length;

		ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class);
		when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferProvider));
		when(partitionWriter.getNumberOfOutputChannels()).thenReturn(numChannels);

		doAnswer(new Answer<Void>() {
			@Override
			public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
				Buffer buffer = (Buffer) invocationOnMock.getArguments()[0];
				if (buffer.isBuffer()) {
					Integer targetChannel = (Integer) invocationOnMock.getArguments()[1];
					queues[targetChannel].add(new BufferOrEvent(buffer, targetChannel));
				} else {
					// is event:
					AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader());
					buffer.recycle(); // the buffer is not needed anymore
					Integer targetChannel = (Integer) invocationOnMock.getArguments()[1];
					queues[targetChannel].add(new BufferOrEvent(event, targetChannel));
				}
				return null;
			}
		}).when(partitionWriter).writeBuffer(any(Buffer.class), anyInt());

		return partitionWriter;
	}

	/**
	 * Creates a buffer provider that allocates a fresh unpooled buffer of the
	 * given size on every request.
	 */
	private BufferProvider createBufferProvider(final int bufferSize)
			throws IOException, InterruptedException {

		BufferProvider bufferProvider = mock(BufferProvider.class);
		when(bufferProvider.requestBufferBlocking()).thenAnswer(
			new Answer<Buffer>() {
				@Override
				public Buffer answer(InvocationOnMock invocationOnMock) throws Throwable {
					MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
					Buffer buffer = new Buffer(segment, DiscardingRecycler.INSTANCE);
					return buffer;
				}
			}
		);

		return bufferProvider;
	}

	/**
	 * Creates a buffer provider that returns the given buffers in order, one
	 * per blocking request (the last buffer is repeated for further requests).
	 */
	private BufferProvider createBufferProvider(Buffer... buffers)
			throws IOException, InterruptedException {

		BufferProvider bufferProvider = mock(BufferProvider.class);

		// Note: stubbing the same call repeatedly in a loop would overwrite the
		// previous stubbing each time, so that only the last buffer would ever be
		// returned. Use a single consecutive-returns stubbing instead.
		if (buffers.length > 0) {
			Buffer[] rest = Arrays.copyOfRange(buffers, 1, buffers.length);
			when(bufferProvider.requestBufferBlocking()).thenReturn(buffers[0], rest);
		}

		return bufferProvider;
	}

	/**
	 * Creates a mock partition writer with a single output channel that
	 * recycles every buffer written to it.
	 */
	private ResultPartitionWriter createResultPartitionWriter(BufferProvider bufferProvider)
			throws IOException {

		ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class);
		when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferProvider));
		when(partitionWriter.getNumberOfOutputChannels()).thenReturn(1);

		// Recycle each written buffer.
		doAnswer(new Answer<Void>() {
			@Override
			public Void answer(InvocationOnMock invocation) throws Throwable {
				((Buffer) invocation.getArguments()[0]).recycle();
				return null;
			}
		}).when(partitionWriter).writeBuffer(any(Buffer.class), anyInt());

		return partitionWriter;
	}

	/**
	 * Simple serializable record wrapping a fixed-size byte array.
	 */
	private static class ByteArrayIO implements IOReadableWritable {

		private final byte[] bytes;

		public ByteArrayIO(byte[] bytes) {
			this.bytes = bytes;
		}

		@Override
		public void write(DataOutputView out) throws IOException {
			out.write(bytes);
		}

		@Override
		public void read(DataInputView in) throws IOException {
			in.readFully(bytes);
		}
	}

	/**
	 * RoundRobin channel selector starting at 0 ({@link RoundRobinChannelSelector} starts at 1).
	 */
	private static class RoundRobin<T extends IOReadableWritable> implements ChannelSelector<T> {

		private int[] nextChannel = new int[] { -1 };

		@Override
		public int[] selectChannels(final T record, final int numberOfOutputChannels) {
			nextChannel[0] = (nextChannel[0] + 1) % numberOfOutputChannels;
			return nextChannel;
		}
	}
}