/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.runtime.io;

import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;

import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Random;

import static org.junit.Assert.*;

/**
 * Tests for the {@link BufferSpiller}.
 *
 * <p>The tests add randomly generated buffers and events to the spiller, roll the
 * spiller over, and verify that the resulting sequence returns exactly the elements
 * that were added, in the same order. After every test the fixture checks that the
 * spiller is closed and that no file is left in the I/O manager's spilling directories.
 *
 * <p>Test and fixture methods declare {@code throws Exception} so that an unexpected
 * exception reaches the test runner with its original type, message and stack trace.
 */
public class BufferSpillerTest {

    private static final Logger LOG = LoggerFactory.getLogger(BufferSpillerTest.class);

    /** Page size handed to the spiller; also the capacity of every generated buffer. */
    private static final int PAGE_SIZE = 4096;

    /** I/O manager shared by all tests of this class. */
    private static IOManager IO_MANAGER;

    /** The spiller under test, re-created for every test method. */
    private BufferSpiller spiller;

    // ------------------------------------------------------------------------
    //  Setup / Cleanup
    // ------------------------------------------------------------------------

    /** Creates the I/O manager shared by all tests. */
    @BeforeClass
    public static void setupIOManager() {
        IO_MANAGER = new IOManagerAsync();
    }

    /** Shuts down the shared I/O manager. */
    @AfterClass
    public static void shutdownIOManager() {
        IO_MANAGER.shutdown();
    }

    /**
     * Creates a fresh spiller for the next test.
     *
     * @throws Exception if the spiller cannot be created
     */
    @Before
    public void createSpiller() throws Exception {
        spiller = new BufferSpiller(IO_MANAGER, PAGE_SIZE);
    }

    /**
     * Closes the spiller and verifies that its channel is closed, its current spill
     * file is gone, and the spilling directories are empty.
     *
     * @throws Exception if closing the spiller fails
     */
    @After
    public void cleanupSpiller() throws Exception {
        if (spiller != null) {
            spiller.close();

            assertFalse(spiller.getCurrentChannel().isOpen());
            assertFalse(spiller.getCurrentSpillFile().exists());
        }

        checkNoTempFilesRemain();
    }

    // ------------------------------------------------------------------------
    //  Tests
    // ------------------------------------------------------------------------

    /** Rolling over a spiller to which nothing was added must yield {@code null}. */
    @Test
    public void testRollOverEmptySequences() throws Exception {
        assertNull(spiller.rollOver());
        assertNull(spiller.rollOver());
        assertNull(spiller.rollOver());
    }

    /**
     * Spills a random sequence of buffers and events, rolls over, and reads the
     * sequence back completely before the next round starts.
     */
    @Test
    public void testSpillAndRollOverSimple() throws Exception {
        final Random rnd = new Random();
        final Random bufferRnd = new Random();

        final int maxNumEventsAndBuffers = 3000;
        final int maxNumChannels = 1656;

        // do multiple spilling / rolling over rounds
        for (int round = 0; round < 5; round++) {

            // buffer sizes and channels come from a separately seeded generator, so
            // the same values can be re-drawn during validation
            final long bufferSeed = rnd.nextLong();
            bufferRnd.setSeed(bufferSeed);

            final int numEventsAndBuffers = rnd.nextInt(maxNumEventsAndBuffers) + 1;
            final int numChannels = rnd.nextInt(maxNumChannels) + 1;

            final ArrayList<BufferOrEvent> events = new ArrayList<BufferOrEvent>(128);

            // generate sequence
            for (int i = 0; i < numEventsAndBuffers; i++) {
                boolean isEvent = rnd.nextDouble() < 0.05d;
                if (isEvent) {
                    BufferOrEvent evt = generateRandomEvent(rnd, numChannels);
                    events.add(evt);
                    spiller.add(evt);
                } else {
                    BufferOrEvent evt = generateRandomBuffer(
                            bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numChannels));
                    spiller.add(evt);
                }
            }

            // reset and create reader
            bufferRnd.setSeed(bufferSeed);

            BufferSpiller.SpilledBufferOrEventSequence seq = spiller.rollOver();
            seq.open();

            // read and validate the sequence
            int numEvent = 0;
            for (int i = 0; i < numEventsAndBuffers; i++) {
                if (validateNext(seq.getNext(), events, numEvent, bufferRnd, numChannels)) {
                    numEvent++;
                }
            }

            // no further data
            assertNull(seq.getNext());

            // all events need to be consumed
            assertEquals(events.size(), numEvent);

            seq.cleanup();
        }
    }

    /**
     * Interleaves spilling of new sequences with reading of previously rolled-over
     * sequences. Every second round rolls over without adding anything and expects
     * {@code null}.
     */
    @Test
    public void testSpillWhileReading() throws Exception {
        LOG.info("Starting SpillWhileReading test");

        final int sequences = 10;

        final Random rnd = new Random();

        final int maxNumEventsAndBuffers = 30000;
        final int maxNumChannels = 1656;

        int sequencesConsumed = 0;

        ArrayDeque<SequenceToConsume> pendingSequences = new ArrayDeque<SequenceToConsume>();
        SequenceToConsume currentSequence = null;
        int currentNumEvents = 0;
        int currentNumRecordAndEvents = 0;

        // do multiple spilling / rolling over rounds
        for (int round = 0; round < 2 * sequences; round++) {

            if (round % 2 == 1) {
                // make this an empty sequence
                assertNull(spiller.rollOver());
            } else {
                // proper spilled sequence
                final long bufferSeed = rnd.nextLong();
                final Random bufferRnd = new Random(bufferSeed);

                final int numEventsAndBuffers = rnd.nextInt(maxNumEventsAndBuffers) + 1;
                final int numChannels = rnd.nextInt(maxNumChannels) + 1;

                final ArrayList<BufferOrEvent> events = new ArrayList<BufferOrEvent>(128);

                int generated = 0;
                while (generated < numEventsAndBuffers) {

                    if (currentSequence == null || rnd.nextDouble() < 0.5) {
                        // add a new record
                        boolean isEvent = rnd.nextDouble() < 0.05;
                        if (isEvent) {
                            BufferOrEvent evt = generateRandomEvent(rnd, numChannels);
                            events.add(evt);
                            spiller.add(evt);
                        } else {
                            BufferOrEvent evt = generateRandomBuffer(
                                    bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numChannels));
                            spiller.add(evt);
                        }
                        generated++;
                    } else {
                        // consume a record
                        if (validateNext(
                                currentSequence.sequence.getNext(),
                                currentSequence.events,
                                currentNumEvents,
                                currentSequence.bufferRnd,
                                currentSequence.numChannels)) {
                            currentNumEvents++;
                        }

                        currentNumRecordAndEvents++;
                        if (currentNumRecordAndEvents == currentSequence.numBuffersAndEvents) {
                            // done with the sequence
                            currentSequence.sequence.cleanup();
                            sequencesConsumed++;

                            // validate we had all events
                            assertEquals(currentSequence.events.size(), currentNumEvents);

                            // reset
                            currentSequence = pendingSequences.pollFirst();
                            if (currentSequence != null) {
                                currentSequence.sequence.open();
                            }

                            currentNumRecordAndEvents = 0;
                            currentNumEvents = 0;
                        }
                    }
                }

                // done generating a sequence. queue it for consumption
                bufferRnd.setSeed(bufferSeed);
                BufferSpiller.SpilledBufferOrEventSequence seq = spiller.rollOver();

                SequenceToConsume stc = new SequenceToConsume(
                        bufferRnd, events, seq, numEventsAndBuffers, numChannels);

                if (currentSequence == null) {
                    currentSequence = stc;
                    stc.sequence.open();
                } else {
                    pendingSequences.addLast(stc);
                }
            }
        }

        // consume all the remainder
        while (currentSequence != null) {
            // consume a record
            if (validateNext(
                    currentSequence.sequence.getNext(),
                    currentSequence.events,
                    currentNumEvents,
                    currentSequence.bufferRnd,
                    currentSequence.numChannels)) {
                currentNumEvents++;
            }

            currentNumRecordAndEvents++;
            if (currentNumRecordAndEvents == currentSequence.numBuffersAndEvents) {
                // done with the sequence
                currentSequence.sequence.cleanup();
                sequencesConsumed++;

                // validate we had all events
                assertEquals(currentSequence.events.size(), currentNumEvents);

                // reset
                currentSequence = pendingSequences.pollFirst();
                if (currentSequence != null) {
                    currentSequence.sequence.open();
                }

                currentNumRecordAndEvents = 0;
                currentNumEvents = 0;
            }
        }

        assertEquals(sequences, sequencesConsumed);
    }

    /**
     * Tests that the static HEADER_SIZE field has valid header size.
     */
    @Test
    public void testHeaderSizeStaticField() throws Exception {
        int size = 13;
        BufferOrEvent boe = generateRandomBuffer(size, 0);
        spiller.add(boe);

        assertEquals(
                "Changed the header format, but did not adjust the HEADER_SIZE field",
                BufferSpiller.HEADER_SIZE + size,
                spiller.getBytesWritten());
    }

    // ------------------------------------------------------------------------
    //  Utils
    // ------------------------------------------------------------------------

    /**
     * Creates a {@link TestEvent} with a random magic number and a random payload of
     * fewer than 1000 bytes, assigned to a random channel.
     *
     * @param rnd the source of randomness
     * @param numChannels exclusive upper bound for the channel index
     * @return the event, wrapped in a {@link BufferOrEvent}
     */
    private static BufferOrEvent generateRandomEvent(Random rnd, int numChannels) {
        long magicNumber = rnd.nextLong();
        byte[] data = new byte[rnd.nextInt(1000)];
        rnd.nextBytes(data);
        TestEvent evt = new TestEvent(magicNumber, data);

        int channelIndex = rnd.nextInt(numChannels);

        return new BufferOrEvent(evt, channelIndex);
    }

    /**
     * Creates a buffer whose first {@code size} bytes hold the pattern
     * {@code (byte) 0, (byte) 1, (byte) 2, ...}.
     *
     * @param size the number of bytes to fill and the size set on the buffer
     * @param channelIndex the channel index to attach to the buffer
     * @return the buffer, wrapped in a {@link BufferOrEvent}
     */
    private static BufferOrEvent generateRandomBuffer(int size, int channelIndex) {
        MemorySegment seg = MemorySegmentFactory.allocateUnpooledSegment(PAGE_SIZE);
        for (int i = 0; i < size; i++) {
            seg.put(i, (byte) i);
        }

        Buffer buf = new Buffer(seg, FreeingBufferRecycler.INSTANCE);
        buf.setSize(size);
        return new BufferOrEvent(buf, channelIndex);
    }

    /**
     * Validates one element read back from a spilled sequence. An event is compared
     * with the expected event at {@code nextEventIndex}; a buffer is checked against
     * the size and channel index drawn (in that order) from {@code bufferRnd}.
     *
     * @param next the element returned by the sequence; must not be {@code null}
     * @param expectedEvents the events that were spilled, in spilling order
     * @param nextEventIndex index into {@code expectedEvents} of the next expected event
     * @param bufferRnd generator re-seeded to replay the buffer sizes and channels
     * @param numChannels exclusive upper bound used when the buffer channels were drawn
     * @return {@code true} if the element was an event, in which case the caller must
     *         advance its event index; {@code false} if it was a buffer
     */
    private static boolean validateNext(
            BufferOrEvent next,
            ArrayList<BufferOrEvent> expectedEvents,
            int nextEventIndex,
            Random bufferRnd,
            int numChannels) {

        assertNotNull(next);

        if (next.isEvent()) {
            BufferOrEvent expected = expectedEvents.get(nextEventIndex);
            assertEquals(expected.getEvent(), next.getEvent());
            assertEquals(expected.getChannelIndex(), next.getChannelIndex());
            return true;
        }

        // the draw order (size first, then channel) must match the generation order
        validateBuffer(next, bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numChannels));
        return false;
    }

    /**
     * Asserts that the given element is a buffer with the expected channel index and
     * size, and that its contents follow the pattern written by
     * {@link #generateRandomBuffer(int, int)}.
     */
    private static void validateBuffer(BufferOrEvent boe, int expectedSize, int expectedChannelIndex) {
        assertEquals("wrong channel index", expectedChannelIndex, boe.getChannelIndex());
        assertTrue("is not buffer", boe.isBuffer());

        Buffer buf = boe.getBuffer();
        assertEquals("wrong buffer size", expectedSize, buf.getSize());

        MemorySegment seg = buf.getMemorySegment();
        for (int i = 0; i < expectedSize; i++) {
            byte expected = (byte) i;
            if (expected != seg.get(i)) {
                fail(String.format(
                        "wrong buffer contents at position %s : expected=%d , found=%d",
                        i, expected, seg.get(i)));
            }
        }
    }

    /**
     * Fails the test if any entry is left in one of the I/O manager's spilling
     * directories, or if such a directory cannot be listed.
     */
    private static void checkNoTempFilesRemain() {
        // validate that all temp files have been removed
        for (File dir : IO_MANAGER.getSpillingDirectories()) {
            // File.list() returns null if the path is not a directory or on I/O error
            String[] files = dir.list();
            assertNotNull("cannot list spilling directory: " + dir, files);

            for (String file : files) {
                if (file != null && !(file.equals(".") || file.equals(".."))) {
                    fail("barrier buffer did not clean up temp files. remaining file: " + file);
                }
            }
        }
    }

    /**
     * Bundles a rolled-over sequence with everything needed to validate it later:
     * the expected events, the re-seeded buffer generator, the total element count
     * and the channel bound.
     */
    private static class SequenceToConsume {

        final BufferSpiller.SpilledBufferOrEventSequence sequence;
        final ArrayList<BufferOrEvent> events;
        final Random bufferRnd;
        final int numBuffersAndEvents;
        final int numChannels;

        private SequenceToConsume(
                Random bufferRnd,
                ArrayList<BufferOrEvent> events,
                BufferSpiller.SpilledBufferOrEventSequence sequence,
                int numBuffersAndEvents,
                int numChannels) {
            this.bufferRnd = bufferRnd;
            this.events = events;
            this.sequence = sequence;
            this.numBuffersAndEvents = numBuffersAndEvents;
            this.numChannels = numChannels;
        }
    }
}