/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.operators.hash; import org.apache.flink.api.common.typeutils.GenericPairComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypePairComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.ByteValueSerializer; import org.apache.flink.api.common.typeutils.base.LongComparator; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArrayComparator; import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.TupleComparator; import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; import org.apache.flink.api.java.typeutils.runtime.ValueComparator; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; import org.apache.flink.types.ByteValue; import org.apache.flink.util.MutableObjectIterator; import org.junit.Test; import org.junit.Assert; import org.mockito.Mockito; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import static org.junit.Assert.*; public class HashTableTest { private final TypeSerializer<Tuple2<Long, byte[]>> buildSerializer; private final TypeSerializer<Long> probeSerializer; private final TypeComparator<Tuple2<Long, byte[]>> buildComparator; private final TypeComparator<Long> probeComparator; private final TypePairComparator<Long, Tuple2<Long, byte[]>> pairComparator; public HashTableTest() { TypeSerializer<?>[] fieldSerializers = { LongSerializer.INSTANCE, BytePrimitiveArraySerializer.INSTANCE }; @SuppressWarnings("unchecked") Class<Tuple2<Long, byte[]>> clazz = (Class<Tuple2<Long, byte[]>>) (Class<?>) Tuple2.class; this.buildSerializer = new TupleSerializer<Tuple2<Long, byte[]>>(clazz, fieldSerializers); this.probeSerializer = LongSerializer.INSTANCE; TypeComparator<?>[] comparators = { new LongComparator(true) }; TypeSerializer<?>[] comparatorSerializers = { LongSerializer.INSTANCE }; this.buildComparator = new TupleComparator<Tuple2<Long, byte[]>>(new int[] {0}, comparators, comparatorSerializers); this.probeComparator = new LongComparator(true); this.pairComparator = new TypePairComparator<Long, Tuple2<Long, byte[]>>() { private long ref; @Override public void setReference(Long reference) { ref = reference; } @Override public boolean equalToReference(Tuple2<Long, byte[]> candidate) { //noinspection UnnecessaryUnboxing return candidate.f0.longValue() == ref; } @Override public int compareToReference(Tuple2<Long, byte[]> candidate) { long x = ref; long y = candidate.f0; return (x < y) ? -1 : ((x == y) ? 0 : 1); } }; } // ------------------------------------------------------------------------ // Tests // ------------------------------------------------------------------------ /** * This tests a combination of values that lead to a corner case situation where memory * was missing and the computation deadlocked. */ @Test public void testBufferMissingForProbing() { final IOManager ioMan = new IOManagerAsync(); try { final int pageSize = 32*1024; final int numSegments = 34; final int numRecords = 3400; final int recordLen = 270; final byte[] payload = new byte[recordLen - 8 - 4]; List<MemorySegment> memory = getMemory(numSegments, pageSize); MutableHashTable<Tuple2<Long, byte[]>, Long> table = new MutableHashTable<>( buildSerializer, probeSerializer, buildComparator, probeComparator, pairComparator, memory, ioMan, 16, false); table.open(new TupleBytesIterator(payload, numRecords), new LongIterator(10000)); try { while (table.nextRecord()) { MutableObjectIterator<Tuple2<Long, byte[]>> matches = table.getBuildSideIterator(); while (matches.next() != null); } } catch (RuntimeException e) { if (!e.getMessage().contains("exceeded maximum number of recursions")) { e.printStackTrace(); fail("Test failed with unexpected exception"); } } finally { table.close(); } checkNoTempFilesRemain(ioMan); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } finally { ioMan.shutdown(); } } /** * This tests the case where no additional partition buffers are used at the point when spilling * is triggered, testing that overflow bucket buffers are taken into account when deciding which * partition to spill. */ @Test public void testSpillingFreesOnlyOverflowSegments() { final IOManager ioMan = new IOManagerAsync(); final TypeSerializer<ByteValue> serializer = ByteValueSerializer.INSTANCE; final TypeComparator<ByteValue> buildComparator = new ValueComparator<>(true, ByteValue.class); final TypeComparator<ByteValue> probeComparator = new ValueComparator<>(true, ByteValue.class); @SuppressWarnings("unchecked") final TypePairComparator<ByteValue, ByteValue> pairComparator = Mockito.mock(TypePairComparator.class); try { final int pageSize = 32*1024; final int numSegments = 34; List<MemorySegment> memory = getMemory(numSegments, pageSize); MutableHashTable<ByteValue, ByteValue> table = new MutableHashTable<>( serializer, serializer, buildComparator, probeComparator, pairComparator, memory, ioMan, 1, false); table.open(new ByteValueIterator(100000000), new ByteValueIterator(1)); table.close(); checkNoTempFilesRemain(ioMan); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } finally { ioMan.shutdown(); } } /** * Tests that the MutableHashTable spills its partitions when creating the initial table * without overflow segments in the partitions. This means that the records are large. */ @Test public void testSpillingWhenBuildingTableWithoutOverflow() throws Exception { final IOManager ioMan = new IOManagerAsync(); final TypeSerializer<byte[]> serializer = BytePrimitiveArraySerializer.INSTANCE; final TypeComparator<byte[]> buildComparator = new BytePrimitiveArrayComparator(true); final TypeComparator<byte[]> probeComparator = new BytePrimitiveArrayComparator(true); @SuppressWarnings("unchecked") final TypePairComparator<byte[], byte[]> pairComparator = new GenericPairComparator<>( new BytePrimitiveArrayComparator(true), new BytePrimitiveArrayComparator(true)); final int pageSize = 128; final int numSegments = 33; List<MemorySegment> memory = getMemory(numSegments, pageSize); MutableHashTable<byte[], byte[]> table = new MutableHashTable<byte[], byte[]>( serializer, serializer, buildComparator, probeComparator, pairComparator, memory, ioMan, 1, false); int numElements = 9; table.open( new CombiningIterator<byte[]>( new ByteArrayIterator(numElements, 128,(byte) 0), new ByteArrayIterator(numElements, 128,(byte) 1)), new CombiningIterator<byte[]>( new ByteArrayIterator(1, 128,(byte) 0), new ByteArrayIterator(1, 128,(byte) 1))); while(table.nextRecord()) { MutableObjectIterator<byte[]> iterator = table.getBuildSideIterator(); int counter = 0; while(iterator.next() != null) { counter++; } // check that we retrieve all our elements Assert.assertEquals(numElements, counter); } table.close(); } // ------------------------------------------------------------------------ // Utilities // ------------------------------------------------------------------------ private static List<MemorySegment> getMemory(int numSegments, int segmentSize) { ArrayList<MemorySegment> list = new ArrayList<MemorySegment>(numSegments); for (int i = 0; i < numSegments; i++) { list.add(MemorySegmentFactory.allocateUnpooledSegment(segmentSize)); } return list; } private static void checkNoTempFilesRemain(IOManager ioManager) { for (File dir : ioManager.getSpillingDirectories()) { for (String file : dir.list()) { if (file != null && !(file.equals(".") || file.equals(".."))) { fail("hash table did not clean up temp files. remaining file: " + file); } } } } private static class TupleBytesIterator implements MutableObjectIterator<Tuple2<Long, byte[]>> { private final byte[] payload; private final int numRecords; private int count = 0; TupleBytesIterator(byte[] payload, int numRecords) { this.payload = payload; this.numRecords = numRecords; } @Override public Tuple2<Long, byte[]> next(Tuple2<Long, byte[]> reuse) { return next(); } @Override public Tuple2<Long, byte[]> next() { if (count++ < numRecords) { return new Tuple2<>(42L, payload); } else { return null; } } } private static class ByteArrayIterator implements MutableObjectIterator<byte[]> { private final long numRecords; private long counter = 0; private final byte[] arrayValue; ByteArrayIterator(long numRecords, int length, byte value) { this.numRecords = numRecords; arrayValue = new byte[length]; Arrays.fill(arrayValue, value); } @Override public byte[] next(byte[] array) { return next(); } @Override public byte[] next() { if (counter++ < numRecords) { return arrayValue; } else { return null; } } } private static class LongIterator implements MutableObjectIterator<Long> { private final long numRecords; private long value = 0; LongIterator(long numRecords) { this.numRecords = numRecords; } @Override public Long next(Long aLong) { return next(); } @Override public Long next() { if (value < numRecords) { return value++; } else { return null; } } } private static class ByteValueIterator implements MutableObjectIterator<ByteValue> { private final long numRecords; private long value = 0; ByteValueIterator(long numRecords) { this.numRecords = numRecords; } @Override public ByteValue next(ByteValue aLong) { return next(); } @Override public ByteValue next() { if (value++ < numRecords) { return new ByteValue((byte) 0); } else { return null; } } } private static class CombiningIterator<T> implements MutableObjectIterator<T> { private final MutableObjectIterator<T> left; private final MutableObjectIterator<T> right; public CombiningIterator(MutableObjectIterator<T> left, MutableObjectIterator<T> right) { this.left = left; this.right = right; } @Override public T next(T reuse) throws IOException { T value = left.next(reuse); if (value == null) { return right.next(reuse); } else { return value; } } @Override public T next() throws IOException { T value = left.next(); if (value == null) { return right.next(); } else { return value; } } } }