/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.array.IntBigArray;
import com.facebook.presto.array.LongBigArray;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.type.BigintOperators;
import com.google.common.collect.ImmutableList;
import org.openjdk.jol.info.ClassLayout;

import java.util.List;

import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE;
import static com.facebook.presto.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
import static com.google.common.base.Preconditions.checkArgument;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;

/**
 * A {@link GroupByHash} specialized for a single BIGINT grouping channel.
 * Values are stored in an open-addressed hash table with linear probing;
 * a null key has no slot in the table and is tracked through a dedicated
 * group id instead.
 */
public class BigintGroupByHash
        implements GroupByHash
{
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(BigintGroupByHash.class).instanceSize();

    private static final float FILL_RATIO = 0.75f;
    private static final List<Type> TYPES = ImmutableList.of(BIGINT);
    private static final List<Type> TYPES_WITH_RAW_HASH = ImmutableList.of(BIGINT, BIGINT);

    private final int hashChannel;
    private final boolean outputRawHash;

    private int hashCapacity;
    private int maxFill;
    private int mask;

    // the hash table from values to groupIds
    private LongBigArray values;
    private IntBigArray groupIds;

    // groupId for the null value
    private int nullGroupId = -1;

    // reverse index from the groupId back to the value
    private final LongBigArray valuesByGroupId;

    private int nextGroupId;
    private long hashCollisions;
    private double expectedHashCollisions;

    public BigintGroupByHash(int hashChannel, boolean outputRawHash, int expectedSize)
    {
        checkArgument(hashChannel >= 0, "hashChannel must be at least zero");
        checkArgument(expectedSize > 0, "expectedSize must be greater than zero");

        this.hashChannel = hashChannel;
        this.outputRawHash = outputRawHash;

        hashCapacity = arraySize(expectedSize, FILL_RATIO);
        maxFill = calculateMaxFill(hashCapacity);
        mask = hashCapacity - 1;

        values = new LongBigArray();
        values.ensureCapacity(hashCapacity);
        groupIds = new IntBigArray(-1);
        groupIds.ensureCapacity(hashCapacity);

        valuesByGroupId = new LongBigArray();
        valuesByGroupId.ensureCapacity(hashCapacity);
    }

    @Override
    public long getEstimatedSize()
    {
        return INSTANCE_SIZE + groupIds.sizeOf() + values.sizeOf() + valuesByGroupId.sizeOf();
    }

    @Override
    public long getHashCollisions()
    {
        return hashCollisions;
    }

    @Override
    public double getExpectedHashCollisions()
    {
        return expectedHashCollisions + estimateNumberOfHashCollisions(getGroupCount(), hashCapacity);
    }

    @Override
    public List<Type> getTypes()
    {
        return outputRawHash ? TYPES_WITH_RAW_HASH : TYPES;
    }
    @Override
    public int getGroupCount()
    {
        return nextGroupId;
    }

    @Override
    public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChannelOffset)
    {
        checkArgument(groupId >= 0, "groupId is negative");
        BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(outputChannelOffset);
        if (groupId == nullGroupId) {
            blockBuilder.appendNull();
        }
        else {
            BIGINT.writeLong(blockBuilder, valuesByGroupId.get(groupId));
        }

        if (outputRawHash) {
            BlockBuilder hashBlockBuilder = pageBuilder.getBlockBuilder(outputChannelOffset + 1);
            if (groupId == nullGroupId) {
                BIGINT.writeLong(hashBlockBuilder, NULL_HASH_CODE);
            }
            else {
                BIGINT.writeLong(hashBlockBuilder, BigintOperators.hashCode(valuesByGroupId.get(groupId)));
            }
        }
    }

    @Override
    public void addPage(Page page)
    {
        int positionCount = page.getPositionCount();

        // get the group id for each position
        Block block = page.getBlock(hashChannel);
        for (int position = 0; position < positionCount; position++) {
            // get the group for the current row
            putIfAbsent(position, block);
        }
    }

    @Override
    public GroupByIdBlock getGroupIds(Page page)
    {
        int positionCount = page.getPositionCount();

        // we know the exact size required for the block
        BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(positionCount);

        // get the group id for each position
        Block block = page.getBlock(hashChannel);
        for (int position = 0; position < positionCount; position++) {
            // get the group for the current row
            int groupId = putIfAbsent(position, block);

            // output the group id for this row
            BIGINT.writeLong(blockBuilder, groupId);
        }
        return new GroupByIdBlock(nextGroupId, blockBuilder.build());
    }

    @Override
    public boolean contains(int position, Page page, int[] hashChannels)
    {
        Block block = page.getBlock(hashChannel);
        if (block.isNull(position)) {
            return nullGroupId >= 0;
        }

        long value = BIGINT.getLong(block, position);
        long hashPosition = getHashPosition(value, mask);

        // look for an empty slot or a slot containing this key
        while (true) {
            int groupId = groupIds.get(hashPosition);
            if (groupId == -1) {
                return false;
            }
            else if (value == values.get(hashPosition)) {
                return true;
            }

            // increment position and mask to handle wrap around
            hashPosition = (hashPosition + 1) & mask;
        }
    }

    @Override
    public int putIfAbsent(int position, Page page)
    {
        Block block = page.getBlock(hashChannel);
        return putIfAbsent(position, block);
    }

    @Override
    public long getRawHash(int groupId)
    {
        return BigintType.hash(valuesByGroupId.get(groupId));
    }

    private int putIfAbsent(int position, Block block)
    {
        if (block.isNull(position)) {
            if (nullGroupId < 0) {
                // set null group id
                nullGroupId = nextGroupId++;
            }
            return nullGroupId;
        }

        long value = BIGINT.getLong(block, position);
        long hashPosition = getHashPosition(value, mask);

        // look for an empty slot or a slot containing this key
        while (true) {
            int groupId = groupIds.get(hashPosition);
            if (groupId == -1) {
                break;
            }

            if (value == values.get(hashPosition)) {
                return groupId;
            }

            // increment position and mask to handle wrap around
            hashPosition = (hashPosition + 1) & mask;
            hashCollisions++;
        }

        return addNewGroup(hashPosition, value);
    }

    private int addNewGroup(long hashPosition, long value)
    {
        // record group id in hash
        int groupId = nextGroupId++;

        values.set(hashPosition, value);
        valuesByGroupId.set(groupId, value);
        groupIds.set(hashPosition, groupId);

        // increase capacity, if necessary
        if (nextGroupId >= maxFill) {
            rehash();
        }
        return groupId;
    }
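    /**
     * Doubles the table capacity and reinserts the value of every existing
     * group. Group ids are stable across a rehash; only slot positions change.
     * The null group is skipped because it has no slot in the table.
     */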
    private void rehash()
    {
        expectedHashCollisions += estimateNumberOfHashCollisions(getGroupCount(), hashCapacity);

        long newCapacityLong = hashCapacity * 2L;
        if (newCapacityLong > Integer.MAX_VALUE) {
            throw new PrestoException(GENERIC_INSUFFICIENT_RESOURCES, "Size of hash table cannot exceed 1 billion entries");
        }
        int newCapacity = (int) newCapacityLong;

        int newMask = newCapacity - 1;
        LongBigArray newValues = new LongBigArray();
        newValues.ensureCapacity(newCapacity);
        IntBigArray newGroupIds = new IntBigArray(-1);
        newGroupIds.ensureCapacity(newCapacity);

        for (int groupId = 0; groupId < nextGroupId; groupId++) {
            if (groupId == nullGroupId) {
                continue;
            }
            long value = valuesByGroupId.get(groupId);

            // find an empty slot for the address
            long hashPosition = getHashPosition(value, newMask);
            while (newGroupIds.get(hashPosition) != -1) {
                hashPosition = (hashPosition + 1) & newMask;
                hashCollisions++;
            }

            // record the mapping
            newValues.set(hashPosition, value);
            newGroupIds.set(hashPosition, groupId);
        }

        mask = newMask;
        hashCapacity = newCapacity;
        maxFill = calculateMaxFill(hashCapacity);
        values = newValues;
        groupIds = newGroupIds;

        this.valuesByGroupId.ensureCapacity(maxFill);
    }

    private static long getHashPosition(long rawHash, int mask)
    {
        // scramble the raw value with the murmur3 finalizer so clustered keys
        // spread evenly, then mask down to the power-of-two table size
        return murmurHash3(rawHash) & mask;
    }

    private static int calculateMaxFill(int hashSize)
    {
        checkArgument(hashSize > 0, "hashSize must be greater than 0");
        // e.g. a table with 16 slots accepts at most 12 entries before a rehash
        int maxFill = (int) Math.ceil(hashSize * FILL_RATIO);
        if (maxFill == hashSize) {
            maxFill--;
        }
        checkArgument(hashSize > maxFill, "hashSize must be larger than maxFill");
        return maxFill;
    }
}
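// A minimal usage sketch, not part of the upstream file: it builds a
// one-column BIGINT page, groups it through BigintGroupByHash, and returns
// the per-row group ids. The example class name and the values are
// illustrative; only APIs already used above are assumed.
class BigintGroupByHashUsageExample
{
    static GroupByIdBlock groupSinglePage()
    {
        // build a one-column BIGINT page holding 42, 7, 42
        BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(3);
        BIGINT.writeLong(builder, 42);
        BIGINT.writeLong(builder, 7);
        BIGINT.writeLong(builder, 42);
        Page page = new Page(builder.build());

        // the grouping key is in channel 0; raw hashes are not emitted;
        // 16 is a hint for the expected number of distinct groups
        BigintGroupByHash groupByHash = new BigintGroupByHash(0, false, 16);

        // rows 0 and 2 hold the same value, so the result has two distinct
        // groups, and rows 0 and 2 map to the same group id
        return groupByHash.getGroupIds(page);
    }
}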