/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.operators.hash; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Random; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypePairComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.LongComparator; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.runtime.TupleComparator; import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.operators.testutils.UniformStringPairGenerator; import org.apache.flink.runtime.operators.testutils.types.IntList; import org.apache.flink.runtime.operators.testutils.types.IntListComparator; import org.apache.flink.runtime.operators.testutils.types.IntListPairComparator; import org.apache.flink.runtime.operators.testutils.types.IntListSerializer; import org.apache.flink.runtime.operators.testutils.types.IntPair; import org.apache.flink.runtime.operators.testutils.types.IntPairComparator; import org.apache.flink.runtime.operators.testutils.types.IntPairListPairComparator; import org.apache.flink.runtime.operators.testutils.types.IntPairPairComparator; import org.apache.flink.runtime.operators.testutils.types.IntPairSerializer; import org.apache.flink.runtime.operators.testutils.types.StringPair; import org.apache.flink.runtime.operators.testutils.types.StringPairComparator; import org.apache.flink.runtime.operators.testutils.types.StringPairPairComparator; import org.apache.flink.runtime.operators.testutils.types.StringPairSerializer; import org.apache.flink.util.MutableObjectIterator; import org.junit.Test; import static org.junit.Assert.*; public abstract class MutableHashTableTestBase { protected static final long RANDOM_SEED = 76518743207143L; private static final int KEY_VALUE_DIFF = 1021; public final int PAGE_SIZE = 16 * 1024; public final TypeSerializer<IntPair> intPairSerializer = new IntPairSerializer(); public final TypeComparator<IntPair> intPairComparator = new IntPairComparator(); public final TypePairComparator<IntPair, IntPair> pairComparator = new IntPairPairComparator(); private static final int MAX_LIST_SIZE = 8; public final TypeSerializer<IntList> serializerV = new IntListSerializer(); public final TypeComparator<IntList> comparatorV = new IntListComparator(); public final TypePairComparator<IntList, IntList> pairComparatorV = new IntListPairComparator(); public final TypePairComparator<IntPair, IntList> pairComparatorPL = new IntPairListPairComparator(); @SuppressWarnings("unchecked") protected TupleSerializer<Tuple2<Long, String>> tuple2LongStringSerializer = new TupleSerializer<>( (Class<Tuple2<Long, String>>) (Class<?>) Tuple2.class, new TypeSerializer<?>[] { LongSerializer.INSTANCE, StringSerializer.INSTANCE }); protected TupleComparator<Tuple2<Long, String>> tuple2LongStringComparator = new TupleComparator<>( new int[] {0}, new TypeComparator<?>[] { new LongComparator(true) }, new TypeSerializer<?>[] { LongSerializer.INSTANCE }); public final int SIZE = 75; public final int NUM_PAIRS = 100000; public final int NUM_LISTS = 100000; protected final int ADDITIONAL_MEM = 100; private final int NUM_REWRITES = 10; public final TypeSerializer<StringPair> serializerS = new StringPairSerializer(); public final TypeComparator<StringPair> comparatorS = new StringPairComparator(); private final TypePairComparator<StringPair, StringPair> pairComparatorS = new StringPairPairComparator(); abstract protected <T> AbstractMutableHashTable<T> getHashTable( TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory); @Test public void testDifferentProbers() { final int NUM_MEM_PAGES = 32 * NUM_PAIRS / PAGE_SIZE; AbstractMutableHashTable<IntPair> table = getHashTable(intPairSerializer, intPairComparator, getMemory(NUM_MEM_PAGES)); AbstractHashTableProber<IntPair, IntPair> prober1 = table.getProber(intPairComparator, pairComparator); AbstractHashTableProber<IntPair, IntPair> prober2 = table.getProber(intPairComparator, pairComparator); assertFalse(prober1 == prober2); table.close(); // (This also tests calling close without calling open first.) assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testBuildAndRetrieve() throws Exception { final int NUM_MEM_PAGES = 32 * NUM_PAIRS / PAGE_SIZE; AbstractMutableHashTable<IntPair> table = getHashTable(intPairSerializer, intPairComparator, getMemory(NUM_MEM_PAGES)); final Random rnd = new Random(RANDOM_SEED); final IntPair[] pairs = getRandomizedIntPairs(NUM_PAIRS, rnd); table.open(); for (int i = 0; i < NUM_PAIRS; i++) { table.insert(pairs[i]); } AbstractHashTableProber<IntPair, IntPair> prober = table.getProber(intPairComparator, pairComparator); IntPair target = new IntPair(); for (int i = 0; i < NUM_PAIRS; i++) { assertNotNull(prober.getMatchFor(pairs[i], target)); assertEquals(pairs[i].getValue(), target.getValue()); } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testEntryIterator() throws Exception { final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE; AbstractMutableHashTable<IntList> table = getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES)); final Random rnd = new Random(RANDOM_SEED); final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd); table.open(); int result = 0; for (int i = 0; i < NUM_LISTS; i++) { table.insert(lists[i]); result += lists[i].getKey(); } MutableObjectIterator<IntList> iter = table.getEntryIterator(); IntList target = new IntList(); int sum = 0; while((target = iter.next(target)) != null) { sum += target.getKey(); } table.close(); assertTrue(sum == result); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testMultipleProbers() throws Exception { final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE; AbstractMutableHashTable<IntList> table = getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES)); final Random rnd = new Random(RANDOM_SEED); final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd); final IntPair[] pairs = getRandomizedIntPairs(NUM_LISTS, rnd); table.open(); for (int i = 0; i < NUM_LISTS; i++) { table.insert(lists[i]); } AbstractHashTableProber<IntList, IntList> listProber = table.getProber(comparatorV, pairComparatorV); AbstractHashTableProber<IntPair, IntList> pairProber = table.getProber(intPairComparator, pairComparatorPL); IntList target = new IntList(); for (int i = 0; i < NUM_LISTS; i++) { assertNotNull(pairProber.getMatchFor(pairs[i], target)); assertNotNull(listProber.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testVariableLengthBuildAndRetrieve() throws Exception { final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE; AbstractMutableHashTable<IntList> table = getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES)); final Random rnd = new Random(RANDOM_SEED); final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd); table.open(); for (int i = 0; i < NUM_LISTS; i++) { try { table.insert(lists[i]); } catch (Exception e) { throw e; } } AbstractHashTableProber<IntList, IntList> prober = table.getProber(comparatorV, pairComparatorV); IntList target = new IntList(); for (int i = 0; i < NUM_LISTS; i++) { assertNotNull(prober.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); } final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd); // test replacing for (int i = 0; i < NUM_LISTS; i++) { table.insertOrReplaceRecord(overwriteLists[i]); } for (int i = 0; i < NUM_LISTS; i++) { assertNotNull("" + i, prober.getMatchFor(overwriteLists[i], target)); assertArrayEquals(overwriteLists[i].getValue(), target.getValue()); } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testVariableLengthBuildAndRetrieveMajorityUpdated() throws Exception { final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE; AbstractMutableHashTable<IntList> table = getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES)); final Random rnd = new Random(RANDOM_SEED); final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd); table.open(); for (int i = 0; i < NUM_LISTS; i++) { table.insert(lists[i]); } AbstractHashTableProber<IntList, IntList> prober = table.getProber(comparatorV, pairComparatorV); IntList target = new IntList(); for (int i = 0; i < NUM_LISTS; i++) { assertNotNull(prober.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); } final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd); // test replacing for (int i = 0; i < NUM_LISTS; i++) { if( i % 100 != 0) { table.insertOrReplaceRecord(overwriteLists[i]); lists[i] = overwriteLists[i]; } } for (int i = 0; i < NUM_LISTS; i++) { assertNotNull("" + i, prober.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testVariableLengthBuildAndRetrieveMinorityUpdated() throws Exception { final int NUM_LISTS = 20000; final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE; AbstractMutableHashTable<IntList> table = getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES)); final int STEP_SIZE = 100; final Random rnd = new Random(RANDOM_SEED); final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd); table.open(); for (int i = 0; i < NUM_LISTS; i++) { table.insert(lists[i]); } AbstractHashTableProber<IntList, IntList> prober = table.getProber(comparatorV, pairComparatorV); IntList target = new IntList(); for (int i = 0; i < NUM_LISTS; i++) { assertNotNull(prober.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); } final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS/STEP_SIZE, rnd); // test replacing for (int i = 0; i < NUM_LISTS; i += STEP_SIZE) { overwriteLists[i/STEP_SIZE].setKey(overwriteLists[i/STEP_SIZE].getKey()*STEP_SIZE); table.insertOrReplaceRecord(overwriteLists[i/STEP_SIZE]); lists[i] = overwriteLists[i/STEP_SIZE]; } for (int i = 0; i < NUM_LISTS; i++) { assertNotNull(prober.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testRepeatedBuildAndRetrieve() throws Exception { final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE; AbstractMutableHashTable<IntList> table = getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES)); final Random rnd = new Random(RANDOM_SEED); final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd); table.open(); for (int i = 0; i < NUM_LISTS; i++) { try { table.insert(lists[i]); } catch (Exception e) { throw e; } } AbstractHashTableProber<IntList, IntList> prober = table.getProber(comparatorV, pairComparatorV); IntList target = new IntList(); for (int i = 0; i < NUM_LISTS; i++) { assertNotNull(prober.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); } IntList[] overwriteLists; for(int k = 0; k < NUM_REWRITES; k++) { overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd); // test replacing for (int i = 0; i < NUM_LISTS; i++) { table.insertOrReplaceRecord(overwriteLists[i]); } for (int i = 0; i < NUM_LISTS; i++) { assertNotNull("" + i, prober.getMatchFor(overwriteLists[i], target)); assertArrayEquals(overwriteLists[i].getValue(), target.getValue()); } } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testProberUpdate() throws Exception { final int NUM_MEM_PAGES = SIZE * NUM_LISTS / PAGE_SIZE; AbstractMutableHashTable<IntList> table = getHashTable(serializerV, comparatorV, getMemory(NUM_MEM_PAGES)); final Random rnd = new Random(RANDOM_SEED); final IntList[] lists = getRandomizedIntLists(NUM_LISTS, rnd); table.open(); for (int i = 0; i < NUM_LISTS; i++) { table.insert(lists[i]); } final IntList[] overwriteLists = getRandomizedIntLists(NUM_LISTS, rnd); AbstractHashTableProber<IntList, IntList> prober = table.getProber(comparatorV, pairComparatorV); IntList target = new IntList(); for (int i = 0; i < NUM_LISTS; i++) { assertNotNull(""+i,prober.getMatchFor(lists[i], target)); assertArrayEquals(lists[i].getValue(), target.getValue()); prober.updateMatch(overwriteLists[i]); } for (int i = 0; i < NUM_LISTS; i++) { assertNotNull("" + i, prober.getMatchFor(overwriteLists[i], target)); assertArrayEquals(overwriteLists[i].getValue(), target.getValue()); } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } @Test public void testVariableLengthStringBuildAndRetrieve() throws IOException { final int NUM_MEM_PAGES = 40 * NUM_PAIRS / PAGE_SIZE; AbstractMutableHashTable<StringPair> table = getHashTable(serializerS, comparatorS, getMemory(NUM_MEM_PAGES)); MutableObjectIterator<StringPair> buildInput = new UniformStringPairGenerator(NUM_PAIRS, 1, false); MutableObjectIterator<StringPair> probeTester = new UniformStringPairGenerator(NUM_PAIRS, 1, false); MutableObjectIterator<StringPair> updater = new UniformStringPairGenerator(NUM_PAIRS, 1, false); MutableObjectIterator<StringPair> updateTester = new UniformStringPairGenerator(NUM_PAIRS, 1, false); table.open(); StringPair target = new StringPair(); while(buildInput.next(target) != null) { table.insert(target); } AbstractHashTableProber<StringPair, StringPair> prober = table.getProber(comparatorS, pairComparatorS); StringPair temp = new StringPair(); while(probeTester.next(target) != null) { assertNotNull("" + target.getKey(), prober.getMatchFor(target, temp)); assertEquals(temp.getValue(), target.getValue()); } while(updater.next(target) != null) { target.setValue(target.getValue()); table.insertOrReplaceRecord(target); } while (updateTester.next(target) != null) { assertNotNull(prober.getMatchFor(target, temp)); assertEquals(target.getValue(), temp.getValue()); } table.close(); assertEquals("Memory lost", NUM_MEM_PAGES, table.getFreeMemory().size()); } protected static IntPair[] getRandomizedIntPairs(int num, Random rnd) { IntPair[] pairs = new IntPair[num]; // create all the pairs, dense for (int i = 0; i < num; i++) { pairs[i] = new IntPair(i, i + KEY_VALUE_DIFF); } // randomly swap them for (int i = 0; i < 2 * num; i++) { int pos1 = rnd.nextInt(num); int pos2 = rnd.nextInt(num); IntPair tmp = pairs[pos1]; pairs[pos1] = pairs[pos2]; pairs[pos2] = tmp; } return pairs; } protected static IntList[] getRandomizedIntLists(int num, Random rnd) { IntList[] lists = new IntList[num]; for (int i = 0; i < num; i++) { int[] value = new int[rnd.nextInt(MAX_LIST_SIZE)+1]; //int[] value = new int[MAX_LIST_SIZE-1]; for (int j = 0; j < value.length; j++) { value[j] = -rnd.nextInt(Integer.MAX_VALUE); } lists[i] = new IntList(i, value); } return lists; } public List<MemorySegment> getMemory(int numPages) { return getMemory(numPages, PAGE_SIZE); } private static List<MemorySegment> getMemory(int numSegments, int segmentSize) { ArrayList<MemorySegment> list = new ArrayList<MemorySegment>(numSegments); for (int i = 0; i < numSegments; i++) { list.add(MemorySegmentFactory.allocateUnpooledSegment(segmentSize)); } return list; } }