/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.aggregation; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.type.Type; import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.ints.IntArrayList; import org.openjdk.jol.info.ClassLayout; import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalLimit; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; import static com.facebook.presto.type.TypeUtils.hashPosition; import static com.facebook.presto.type.TypeUtils.positionEqualsPosition; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static it.unimi.dsi.fastutil.HashCommon.arraySize; import static java.util.Objects.requireNonNull; public class TypedSet { private static final int INSTANCE_SIZE = ClassLayout.parseClass(TypedSet.class).instanceSize(); private static final int INT_ARRAY_LIST_INSTANCE_SIZE = ClassLayout.parseClass(IntArrayList.class).instanceSize(); private static final float FILL_RATIO = 0.75f; private static final long FOUR_MEGABYTES = new DataSize(4, MEGABYTE).toBytes(); private final Type elementType; private final IntArrayList blockPositionByHash; private final BlockBuilder elementBlock; private int hashCapacity; private int maxFill; private int hashMask; private static final int EMPTY_SLOT = -1; private boolean containsNullElement; public TypedSet(Type elementType, int expectedSize) { checkArgument(expectedSize >= 0, "expectedSize must not be negative"); this.elementType = requireNonNull(elementType, "elementType must not be null"); this.elementBlock = elementType.createBlockBuilder(new BlockBuilderStatus(), expectedSize); hashCapacity = arraySize(expectedSize, FILL_RATIO); this.maxFill = calculateMaxFill(hashCapacity); this.hashMask = hashCapacity - 1; blockPositionByHash = new IntArrayList(hashCapacity); blockPositionByHash.size(hashCapacity); for (int i = 0; i < hashCapacity; i++) { blockPositionByHash.set(i, EMPTY_SLOT); } this.containsNullElement = false; } public long getRetainedSizeInBytes() { return INSTANCE_SIZE + INT_ARRAY_LIST_INSTANCE_SIZE + elementBlock.getRetainedSizeInBytes() + blockPositionByHash.size() * Integer.BYTES; } public boolean contains(Block block, int position) { requireNonNull(block, "block must not be null"); checkArgument(position >= 0, "position must be >= 0"); if (block.isNull(position)) { return containsNullElement; } else { return blockPositionByHash.get(getHashPositionOfElement(block, position)) != EMPTY_SLOT; } } public void add(Block block, int position) { requireNonNull(block, "block must not be null"); checkArgument(position >= 0, "position must be >= 0"); if (block.isNull(position)) { containsNullElement = true; } else { int hashPosition = getHashPositionOfElement(block, position); if (blockPositionByHash.get(hashPosition) == EMPTY_SLOT) { addNewElement(hashPosition, block, position); } } } public int size() { return elementBlock.getPositionCount() + (containsNullElement ? 1 : 0); } public int positionOf(Block block, int position) { return blockPositionByHash.get(getHashPositionOfElement(block, position)); } /** * Get slot position of element at {@code position} of {@code block} */ private int getHashPositionOfElement(Block block, int position) { int hashPosition = getMaskedHash(hashPosition(elementType, block, position)); while (true) { int blockPosition = blockPositionByHash.get(hashPosition); // Doesn't have this element if (blockPosition == EMPTY_SLOT) { return hashPosition; } // Already has this element else if (positionEqualsPosition(elementType, elementBlock, blockPosition, block, position)) { return hashPosition; } hashPosition = getMaskedHash(hashPosition + 1); } } private void addNewElement(int hashPosition, Block block, int position) { elementType.appendTo(block, position, elementBlock); if (elementBlock.getSizeInBytes() > FOUR_MEGABYTES) { throw exceededLocalLimit(new DataSize(4, MEGABYTE)); } blockPositionByHash.set(hashPosition, elementBlock.getPositionCount() - 1); // increase capacity, if necessary if (elementBlock.getPositionCount() >= maxFill) { rehash(); } } private void rehash() { long newCapacityLong = hashCapacity * 2L; if (newCapacityLong > Integer.MAX_VALUE) { throw new PrestoException(GENERIC_INSUFFICIENT_RESOURCES, "Size of hash table cannot exceed 1 billion entries"); } int newCapacity = (int) newCapacityLong; hashCapacity = newCapacity; hashMask = newCapacity - 1; maxFill = calculateMaxFill(newCapacity); blockPositionByHash.size(newCapacity); for (int i = 0; i < newCapacity; i++) { blockPositionByHash.set(i, EMPTY_SLOT); } rehashBlock(elementBlock); } private void rehashBlock(Block block) { for (int blockPosition = 0; blockPosition < block.getPositionCount(); blockPosition++) { blockPositionByHash.set(getHashPositionOfElement(block, blockPosition), blockPosition); } } private static int calculateMaxFill(int hashSize) { checkArgument(hashSize > 0, "hashSize must be greater than 0"); int maxFill = (int) Math.ceil(hashSize * FILL_RATIO); if (maxFill == hashSize) { maxFill--; } checkArgument(hashSize > maxFill, "hashSize must be larger than maxFill"); return maxFill; } private int getMaskedHash(long rawHash) { return (int) (rawHash & hashMask); } }