/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.array.LongBigArray;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.DictionaryBlock;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import org.openjdk.jol.info.ClassLayout;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.operator.SyntheticAddress.decodePosition;
import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex;
import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.sql.gen.JoinCompiler.PagesHashStrategyFactory;
import static com.facebook.presto.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static io.airlift.slice.SizeOf.sizeOf;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
import static java.util.Objects.requireNonNull;

// This implementation assumes arrays used in the hash are always a power of 2
public class MultiChannelGroupByHash
        implements GroupByHash
{
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(MultiChannelGroupByHash.class).instanceSize();
    private static final float FILL_RATIO = 0.75f;

    private final List<Type> types;
    private final List<Type> hashTypes;
    private final int[] channels;

    private final PagesHashStrategy hashStrategy;
    private final List<ObjectArrayList<Block>> channelBuilders;
    private final Optional<Integer> inputHashChannel;
    private final HashGenerator hashGenerator;
    private final Optional<Integer> precomputedHashChannel;
    private final boolean processDictionary;
    private PageBuilder currentPageBuilder;

    private long completedPagesMemorySize;

    private int hashCapacity;
    private int maxFill;
    private int mask;
    private long[] groupAddressByHash;
    private int[] groupIdsByHash;
    private byte[] rawHashByHashPosition;

    private final LongBigArray groupAddressByGroupId;

    private int nextGroupId;
    private DictionaryLookBack dictionaryLookBack;

    private long hashCollisions;
    private double expectedHashCollisions;
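    // Layout note (added for clarity): the table is three parallel arrays indexed by
    // hash slot -- groupAddressByHash (synthetic address of the stored row, -1 if the
    // slot is empty), rawHashByHashPosition (one byte of the raw hash, used to skip
    // expensive row comparisons on mismatch), and groupIdsByHash (the assigned group
    // id). groupAddressByGroupId provides the reverse mapping from group id to row.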
    public MultiChannelGroupByHash(
            List<? extends Type> hashTypes,
            int[] hashChannels,
            Optional<Integer> inputHashChannel,
            int expectedSize,
            boolean processDictionary,
            JoinCompiler joinCompiler)
    {
        this.hashTypes = ImmutableList.copyOf(requireNonNull(hashTypes, "hashTypes is null"));

        requireNonNull(joinCompiler, "joinCompiler is null");
        requireNonNull(hashChannels, "hashChannels is null");
        checkArgument(hashTypes.size() == hashChannels.length, "hashTypes and hashChannels have different sizes");
        checkArgument(expectedSize > 0, "expectedSize must be greater than zero");

        this.inputHashChannel = requireNonNull(inputHashChannel, "inputHashChannel is null");
        this.types = inputHashChannel.isPresent() ? ImmutableList.copyOf(Iterables.concat(hashTypes, ImmutableList.of(BIGINT))) : this.hashTypes;
        this.channels = hashChannels.clone();

        this.hashGenerator = inputHashChannel.isPresent() ? new PrecomputedHashGenerator(inputHashChannel.get()) : new InterpretedHashGenerator(this.hashTypes, hashChannels);
        this.processDictionary = processDictionary;

        // For each hashed channel, create an appendable list to hold the blocks (builders). As we
        // add new values we append them to the existing block builder until it fills up and then
        // we add a new block builder to each list.
        ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder();
        ImmutableList.Builder<ObjectArrayList<Block>> channelBuilders = ImmutableList.builder();
        for (int i = 0; i < hashChannels.length; i++) {
            outputChannels.add(i);
            channelBuilders.add(ObjectArrayList.wrap(new Block[1024], 0));
        }
        if (inputHashChannel.isPresent()) {
            this.precomputedHashChannel = Optional.of(hashChannels.length);
            channelBuilders.add(ObjectArrayList.wrap(new Block[1024], 0));
        }
        else {
            this.precomputedHashChannel = Optional.empty();
        }
        this.channelBuilders = channelBuilders.build();
        PagesHashStrategyFactory pagesHashStrategyFactory = joinCompiler.compilePagesHashStrategyFactory(this.types, outputChannels.build());
        hashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(this.channelBuilders, this.precomputedHashChannel);

        startNewPage();

        // reserve memory for the arrays
        hashCapacity = arraySize(expectedSize, FILL_RATIO);

        maxFill = calculateMaxFill(hashCapacity);
        mask = hashCapacity - 1;
        groupAddressByHash = new long[hashCapacity];
        Arrays.fill(groupAddressByHash, -1);
        rawHashByHashPosition = new byte[hashCapacity];
        groupIdsByHash = new int[hashCapacity];

        groupAddressByGroupId = new LongBigArray();
        groupAddressByGroupId.ensureCapacity(maxFill);
    }

    @Override
    public long getRawHash(int groupId)
    {
        long address = groupAddressByGroupId.get(groupId);
        int blockIndex = decodeSliceIndex(address);
        int position = decodePosition(address);
        return hashStrategy.hashPosition(blockIndex, position);
    }

    @Override
    public long getEstimatedSize()
    {
        return INSTANCE_SIZE +
                (sizeOf(channelBuilders.get(0).elements()) * channelBuilders.size()) +
                completedPagesMemorySize +
                currentPageBuilder.getRetainedSizeInBytes() +
                sizeOf(groupAddressByHash) +
                sizeOf(groupIdsByHash) +
                groupAddressByGroupId.sizeOf() +
                sizeOf(rawHashByHashPosition);
    }

    @Override
    public long getHashCollisions()
    {
        return hashCollisions;
    }

    @Override
    public double getExpectedHashCollisions()
    {
        return expectedHashCollisions + estimateNumberOfHashCollisions(getGroupCount(), hashCapacity);
    }

    @Override
    public List<Type> getTypes()
    {
        return types;
    }

    @Override
    public int getGroupCount()
    {
        return nextGroupId;
    }
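    // Usage sketch (illustrative only; "joinCompiler" and "page" are assumed to
    // exist in the caller's scope):
    //
    //   GroupByHash groupByHash = new MultiChannelGroupByHash(
    //           ImmutableList.of(BIGINT),   // one BIGINT grouping channel
    //           new int[] {0},              // its channel index in the input page
    //           Optional.empty(),           // no precomputed hash channel
    //           10_000,                     // expected number of distinct groups
    //           false,                      // dictionary fast path disabled
    //           joinCompiler);
    //   GroupByIdBlock groupIds = groupByHash.getGroupIds(page);  // one group id per input row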
    @Override
    public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChannelOffset)
    {
        long address = groupAddressByGroupId.get(groupId);
        int blockIndex = decodeSliceIndex(address);
        int position = decodePosition(address);
        hashStrategy.appendTo(blockIndex, position, pageBuilder, outputChannelOffset);
    }

    @Override
    public void addPage(Page page)
    {
        if (canProcessDictionary(page)) {
            addDictionaryPage(page);
            return;
        }

        // get the group id for each position
        int positionCount = page.getPositionCount();
        for (int position = 0; position < positionCount; position++) {
            // get the group for the current row
            putIfAbsent(position, page);
        }
    }

    @Override
    public GroupByIdBlock getGroupIds(Page page)
    {
        int positionCount = page.getPositionCount();

        // we know the exact size required for the block
        BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(positionCount);

        if (canProcessDictionary(page)) {
            Block groupIds = processDictionary(page);
            return new GroupByIdBlock(nextGroupId, groupIds);
        }

        // get the group id for each position
        for (int position = 0; position < positionCount; position++) {
            // get the group for the current row
            int groupId = putIfAbsent(position, page);

            // output the group id for this row
            BIGINT.writeLong(blockBuilder, groupId);
        }
        return new GroupByIdBlock(nextGroupId, blockBuilder.build());
    }

    @Override
    public boolean contains(int position, Page page, int[] hashChannels)
    {
        long rawHash = hashStrategy.hashRow(position, page);
        int hashPosition = (int) getHashPosition(rawHash, mask);

        // look for a slot containing this key
        while (groupAddressByHash[hashPosition] != -1) {
            if (positionEqualsCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, hashChannels)) {
                // found an existing slot for this key
                return true;
            }
            // increment position and mask to handle wrap around
            hashPosition = (hashPosition + 1) & mask;
        }
        return false;
    }

    @Override
    public int putIfAbsent(int position, Page page)
    {
        long rawHash = hashGenerator.hashPosition(position, page);
        return putIfAbsent(position, page, rawHash);
    }

    private int putIfAbsent(int position, Page page, long rawHash)
    {
        int hashPosition = (int) getHashPosition(rawHash, mask);

        // look for an empty slot or a slot containing this key
        int groupId = -1;
        while (groupAddressByHash[hashPosition] != -1) {
            if (positionEqualsCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) {
                // found an existing slot for this key
                groupId = groupIdsByHash[hashPosition];
                break;
            }
            // increment position and mask to handle wrap around
            hashPosition = (hashPosition + 1) & mask;
            hashCollisions++;
        }

        // did we find an existing group?
        if (groupId < 0) {
            groupId = addNewGroup(hashPosition, position, page, rawHash);
        }
        return groupId;
    }
    private int addNewGroup(int hashPosition, int position, Page page, long rawHash)
    {
        // add the row to the open page
        for (int i = 0; i < channels.length; i++) {
            int hashChannel = channels[i];
            Type type = types.get(i);
            type.appendTo(page.getBlock(hashChannel), position, currentPageBuilder.getBlockBuilder(i));
        }
        if (precomputedHashChannel.isPresent()) {
            BIGINT.writeLong(currentPageBuilder.getBlockBuilder(precomputedHashChannel.get()), rawHash);
        }
        currentPageBuilder.declarePosition();
        int pageIndex = channelBuilders.get(0).size() - 1;
        int pagePosition = currentPageBuilder.getPositionCount() - 1;
        long address = encodeSyntheticAddress(pageIndex, pagePosition);

        // record group id in hash
        int groupId = nextGroupId++;

        groupAddressByHash[hashPosition] = address;
        rawHashByHashPosition[hashPosition] = (byte) rawHash;
        groupIdsByHash[hashPosition] = groupId;
        groupAddressByGroupId.set(groupId, address);

        // create new page builder if this page is full
        if (currentPageBuilder.isFull()) {
            startNewPage();
        }

        // increase capacity, if necessary
        if (nextGroupId >= maxFill) {
            rehash();
        }
        return groupId;
    }

    private void startNewPage()
    {
        if (currentPageBuilder != null) {
            completedPagesMemorySize += currentPageBuilder.getRetainedSizeInBytes();
            currentPageBuilder = currentPageBuilder.newPageBuilderLike();
        }
        else {
            currentPageBuilder = new PageBuilder(types);
        }

        for (int i = 0; i < types.size(); i++) {
            channelBuilders.get(i).add(currentPageBuilder.getBlockBuilder(i));
        }
    }

    private void rehash()
    {
        expectedHashCollisions += estimateNumberOfHashCollisions(getGroupCount(), hashCapacity);

        long newCapacityLong = hashCapacity * 2L;
        if (newCapacityLong > Integer.MAX_VALUE) {
            throw new PrestoException(GENERIC_INSUFFICIENT_RESOURCES, "Size of hash table cannot exceed 1 billion entries");
        }
        int newCapacity = (int) newCapacityLong;

        int newMask = newCapacity - 1;
        long[] newKey = new long[newCapacity];
        byte[] rawHashes = new byte[newCapacity];
        Arrays.fill(newKey, -1);
        int[] newValue = new int[newCapacity];

        int oldIndex = 0;
        for (int groupId = 0; groupId < nextGroupId; groupId++) {
            // seek to the next used slot
            while (groupAddressByHash[oldIndex] == -1) {
                oldIndex++;
            }

            // get the address for this slot
            long address = groupAddressByHash[oldIndex];

            long rawHash = hashPosition(address);
            // find an empty slot for the address
            int pos = (int) getHashPosition(rawHash, newMask);
            while (newKey[pos] != -1) {
                pos = (pos + 1) & newMask;
                hashCollisions++;
            }

            // record the mapping
            newKey[pos] = address;
            rawHashes[pos] = (byte) rawHash;
            newValue[pos] = groupIdsByHash[oldIndex];
            oldIndex++;
        }

        this.mask = newMask;
        this.hashCapacity = newCapacity;
        this.maxFill = calculateMaxFill(newCapacity);
        this.groupAddressByHash = newKey;
        this.rawHashByHashPosition = rawHashes;
        this.groupIdsByHash = newValue;
        groupAddressByGroupId.ensureCapacity(maxFill);
    }

    private long hashPosition(long sliceAddress)
    {
        int sliceIndex = decodeSliceIndex(sliceAddress);
        int position = decodePosition(sliceAddress);
        if (precomputedHashChannel.isPresent()) {
            return getRawHash(sliceIndex, position);
        }
        return hashStrategy.hashPosition(sliceIndex, position);
    }

    private long getRawHash(int sliceIndex, int position)
    {
        return channelBuilders.get(precomputedHashChannel.get()).get(sliceIndex).getLong(position, 0);
    }
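    // Sizing note (added for clarity): with FILL_RATIO = 0.75, a capacity of 1024
    // gives maxFill = ceil(1024 * 0.75) = 768, so adding the 768th group triggers a
    // rehash that doubles capacity to 2048 and re-inserts every existing group.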
    private boolean positionEqualsCurrentRow(long address, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels)
    {
        if (rawHashByHashPosition[hashPosition] != rawHash) {
            return false;
        }
        return hashStrategy.positionEqualsRow(decodeSliceIndex(address), decodePosition(address), position, page, hashChannels);
    }

    private static long getHashPosition(long rawHash, int mask)
    {
        return murmurHash3(rawHash) & mask;
    }

    private static int calculateMaxFill(int hashSize)
    {
        checkArgument(hashSize > 0, "hashSize must be greater than 0");
        int maxFill = (int) Math.ceil(hashSize * FILL_RATIO);
        if (maxFill == hashSize) {
            maxFill--;
        }
        checkArgument(hashSize > maxFill, "hashSize must be larger than maxFill");
        return maxFill;
    }

    private void addDictionaryPage(Page page)
    {
        verify(canProcessDictionary(page), "invalid call to addDictionaryPage");

        DictionaryBlock dictionaryBlock = (DictionaryBlock) page.getBlock(channels[0]);
        updateDictionaryLookBack(dictionaryBlock.getDictionary());
        Page dictionaryPage = createPageWithExtractedDictionary(page);

        for (int i = 0; i < page.getPositionCount(); i++) {
            int positionInDictionary = dictionaryBlock.getId(i);
            getGroupId(hashGenerator, dictionaryPage, positionInDictionary);
        }
    }

    private void updateDictionaryLookBack(Block dictionary)
    {
        if (dictionaryLookBack == null || dictionaryLookBack.getDictionary() != dictionary) {
            dictionaryLookBack = new DictionaryLookBack(dictionary);
        }
    }

    private Block processDictionary(Page page)
    {
        verify(canProcessDictionary(page), "invalid call to processDictionary");

        DictionaryBlock dictionaryBlock = (DictionaryBlock) page.getBlock(channels[0]);
        updateDictionaryLookBack(dictionaryBlock.getDictionary());
        Page dictionaryPage = createPageWithExtractedDictionary(page);

        BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(page.getPositionCount());
        for (int i = 0; i < page.getPositionCount(); i++) {
            int positionInDictionary = dictionaryBlock.getId(i);
            int groupId = getGroupId(hashGenerator, dictionaryPage, positionInDictionary);
            BIGINT.writeLong(blockBuilder, groupId);
        }
        verify(blockBuilder.getPositionCount() == page.getPositionCount(), "invalid position count");
        return blockBuilder.build();
    }

    // For a page that contains DictionaryBlocks, create a new page in which
    // the dictionaries from the DictionaryBlocks are extracted into the corresponding channels
    // From Page(DictionaryBlock1, DictionaryBlock2) create new page with Page(dictionary1, dictionary2)
    private Page createPageWithExtractedDictionary(Page page)
    {
        Block[] blocks = new Block[page.getChannelCount()];
        Block dictionary = ((DictionaryBlock) page.getBlock(channels[0])).getDictionary();

        // extract data dictionary
        blocks[channels[0]] = dictionary;

        // extract hash dictionary
        if (inputHashChannel.isPresent()) {
            blocks[inputHashChannel.get()] = ((DictionaryBlock) page.getBlock(inputHashChannel.get())).getDictionary();
        }

        return new Page(dictionary.getPositionCount(), blocks);
    }

    private boolean canProcessDictionary(Page page)
    {
        boolean processDictionary = this.processDictionary &&
                channels.length == 1 &&
                page.getBlock(channels[0]) instanceof DictionaryBlock;

        if (processDictionary && inputHashChannel.isPresent()) {
            Block inputHashBlock = page.getBlock(inputHashChannel.get());
            DictionaryBlock inputDataBlock = (DictionaryBlock) page.getBlock(channels[0]);

            verify(inputHashBlock instanceof DictionaryBlock, "data channel is dictionary encoded but hash channel is not");
            verify(((DictionaryBlock) inputHashBlock).getDictionarySourceId().equals(inputDataBlock.getDictionarySourceId()),
                    "dictionarySourceIds of data block and hash block do not match");
        }

        return processDictionary;
    }
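    // Fast-path note (added for clarity): when the single grouping channel arrives
    // dictionary encoded, each distinct dictionary entry is hashed at most once and
    // the resulting group id is memoized in DictionaryLookBack. For example, a page
    // of 1000 positions backed by a 3-entry dictionary needs at most 3 putIfAbsent
    // calls instead of 1000.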
    private int getGroupId(HashGenerator hashGenerator, Page page, int positionInDictionary)
    {
        if (dictionaryLookBack.isProcessed(positionInDictionary)) {
            return dictionaryLookBack.getGroupId(positionInDictionary);
        }

        int groupId = putIfAbsent(positionInDictionary, page, hashGenerator.hashPosition(positionInDictionary, page));
        dictionaryLookBack.setProcessed(positionInDictionary, groupId);
        return groupId;
    }

    private static final class DictionaryLookBack
    {
        private final Block dictionary;
        private final int[] processed;

        public DictionaryLookBack(Block dictionary)
        {
            this.dictionary = dictionary;
            this.processed = new int[dictionary.getPositionCount()];
            Arrays.fill(processed, -1);
        }

        public Block getDictionary()
        {
            return dictionary;
        }

        public int getGroupId(int position)
        {
            return processed[position];
        }

        public boolean isProcessed(int position)
        {
            return processed[position] != -1;
        }

        public void setProcessed(int position, int groupId)
        {
            processed[position] = groupId;
        }
    }
}