/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
import com.facebook.presto.type.TypeUtils;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Optional;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

/**
 * A {@link PagesHashStrategy} backed by column-major storage: {@code channels.get(c).get(b)}
 * is the {@link Block} holding channel {@code c} of block index {@code b}.
 *
 * <p>Hashing and equality consider only the channels listed in {@code hashChannels};
 * {@code outputChannels} selects which channels {@link #appendTo} copies out.
 */
public class SimplePagesHashStrategy
        implements PagesHashStrategy
{
    private final List<Type> types;
    private final List<Integer> outputChannels;
    private final List<List<Block>> channels;
    private final List<Integer> hashChannels;
    // null when no precomputed hash channel was supplied to the constructor
    private final List<Block> precomputedHashChannel;
    private final Optional<SortExpression> sortChannel;

    /**
     * @param types type of each channel; must have the same length as {@code channels}
     * @param outputChannels channel indexes copied by {@link #appendTo}, in output order
     * @param channels per-channel list of blocks
     * @param hashChannels channel indexes that participate in hashing and equality
     * @param precomputedHashChannel index into {@code channels} of a channel read as BIGINT
     *        hash values by {@link #hashPosition}, or empty to compute hashes on demand
     * @param sortChannel sort expression whose channel is used by {@link #compare}, or empty
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if {@code types} and {@code channels} differ in length
     */
    public SimplePagesHashStrategy(
            List<Type> types,
            List<Integer> outputChannels,
            List<List<Block>> channels,
            List<Integer> hashChannels,
            Optional<Integer> precomputedHashChannel,
            Optional<SortExpression> sortChannel)
    {
        this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
        this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null"));
        this.channels = ImmutableList.copyOf(requireNonNull(channels, "channels is null"));

        checkArgument(types.size() == channels.size(), "Expected types and channels to be the same length");
        this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null"));

        // Fail with a descriptive message, like every other argument, instead of a bare NPE on isPresent().
        requireNonNull(precomputedHashChannel, "precomputedHashChannel is null");
        if (precomputedHashChannel.isPresent()) {
            this.precomputedHashChannel = channels.get(precomputedHashChannel.get());
        }
        else {
            this.precomputedHashChannel = null;
        }
        this.sortChannel = requireNonNull(sortChannel, "sortChannel is null");
    }

    /**
     * Returns the number of output channels.
     */
    @Override
    public int getChannelCount()
    {
        return outputChannels.size();
    }

    /**
     * Returns the sum of {@link Block#getRetainedSizeInBytes()} over every block of every channel.
     */
    @Override
    public long getSizeInBytes()
    {
        return channels.stream()
                .flatMap(List::stream)
                .mapToLong(Block::getRetainedSizeInBytes)
                .sum();
    }

    /**
     * Appends the value at ({@code blockIndex}, {@code position}) of each output channel to
     * consecutive block builders of {@code pageBuilder}, starting at {@code outputChannelOffset}.
     */
    @Override
    public void appendTo(int blockIndex, int position, PageBuilder pageBuilder, int outputChannelOffset)
    {
        for (int outputIndex : outputChannels) {
            Type type = types.get(outputIndex);
            List<Block> channel = channels.get(outputIndex);
            Block block = channel.get(blockIndex);
            type.appendTo(block, position, pageBuilder.getBlockBuilder(outputChannelOffset));
            outputChannelOffset++;
        }
    }

    /**
     * Returns the hash of the stored row at ({@code blockIndex}, {@code position}).
     *
     * <p>When a precomputed hash channel was configured its BIGINT value is returned as is;
     * otherwise the hash is combined over the hash channels as {@code result * 31 + hash}.
     */
    @Override
    public long hashPosition(int blockIndex, int position)
    {
        if (precomputedHashChannel != null) {
            return BIGINT.getLong(precomputedHashChannel.get(blockIndex), position);
        }
        long result = 0;
        for (int hashChannel : hashChannels) {
            Type type = types.get(hashChannel);
            Block block = channels.get(hashChannel).get(blockIndex);
            result = result * 31 + TypeUtils.hashPosition(type, block, position);
        }
        return result;
    }

    /**
     * Returns the hash of row {@code position} of {@code page}, combined the same way as the
     * non-precomputed path of {@link #hashPosition}.
     *
     * <p>The page is read by ordinal: its {@code i}-th block is hashed with the type of the
     * {@code i}-th hash channel. The precomputed hash channel is not consulted here.
     */
    @Override
    public long hashRow(int position, Page page)
    {
        long result = 0;
        for (int i = 0; i < hashChannels.size(); i++) {
            int hashChannel = hashChannels.get(i);
            Type type = types.get(hashChannel);
            Block block = page.getBlock(i);
            result = result * 31 + TypeUtils.hashPosition(type, block, position);
        }
        return result;
    }

    /**
     * Compares two page rows over the hash channels using
     * {@code TypeUtils.positionEqualsPosition}; both pages are read by hash-channel ordinal.
     */
    @Override
    public boolean rowEqualsRow(int leftPosition, Page leftPage, int rightPosition, Page rightPage)
    {
        for (int i = 0; i < hashChannels.size(); i++) {
            int hashChannel = hashChannels.get(i);
            Type type = types.get(hashChannel);
            Block leftBlock = leftPage.getBlock(i);
            Block rightBlock = rightPage.getBlock(i);
            if (!TypeUtils.positionEqualsPosition(type, leftBlock, leftPosition, rightBlock, rightPosition)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares a stored row with a page row over the hash channels using
     * {@code TypeUtils.positionEqualsPosition}; {@code rightPage} is read by hash-channel ordinal.
     */
    @Override
    public boolean positionEqualsRow(int leftBlockIndex, int leftPosition, int rightPosition, Page rightPage)
    {
        for (int i = 0; i < hashChannels.size(); i++) {
            int hashChannel = hashChannels.get(i);
            Type type = types.get(hashChannel);
            Block leftBlock = channels.get(hashChannel).get(leftBlockIndex);
            Block rightBlock = rightPage.getBlock(i);
            if (!TypeUtils.positionEqualsPosition(type, leftBlock, leftPosition, rightBlock, rightPosition)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Same comparison as {@link #positionEqualsRow(int, int, int, Page)} but calls
     * {@link Type#equalTo} directly instead of going through {@code TypeUtils}.
     */
    @Override
    public boolean positionEqualsRowIgnoreNulls(int leftBlockIndex, int leftPosition, int rightPosition, Page rightPage)
    {
        // NOTE(review): presumably callers guarantee neither side is null here - confirm against callers.
        for (int i = 0; i < hashChannels.size(); i++) {
            int hashChannel = hashChannels.get(i);
            Type type = types.get(hashChannel);
            Block leftBlock = channels.get(hashChannel).get(leftBlockIndex);
            Block rightBlock = rightPage.getBlock(i);
            if (!type.equalTo(leftBlock, leftPosition, rightBlock, rightPosition)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares a stored row with a page row over the hash channels, where the block of
     * {@code page} matching the {@code i}-th hash channel is {@code rightHashChannels[i]}.
     */
    @Override
    public boolean positionEqualsRow(int leftBlockIndex, int leftPosition, int rightPosition, Page page, int[] rightHashChannels)
    {
        for (int i = 0; i < hashChannels.size(); i++) {
            int hashChannel = hashChannels.get(i);
            Type type = types.get(hashChannel);
            Block leftBlock = channels.get(hashChannel).get(leftBlockIndex);
            Block rightBlock = page.getBlock(rightHashChannels[i]);
            if (!TypeUtils.positionEqualsPosition(type, leftBlock, leftPosition, rightBlock, rightPosition)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares two stored rows over the hash channels using
     * {@code TypeUtils.positionEqualsPosition}.
     */
    @Override
    public boolean positionEqualsPosition(int leftBlockIndex, int leftPosition, int rightBlockIndex, int rightPosition)
    {
        for (int hashChannel : hashChannels) {
            Type type = types.get(hashChannel);
            List<Block> channel = channels.get(hashChannel);
            Block leftBlock = channel.get(leftBlockIndex);
            Block rightBlock = channel.get(rightBlockIndex);
            if (!TypeUtils.positionEqualsPosition(type, leftBlock, leftPosition, rightBlock, rightPosition)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Same comparison as {@link #positionEqualsPosition} but calls {@link Type#equalTo}
     * directly instead of going through {@code TypeUtils}.
     */
    @Override
    public boolean positionEqualsPositionIgnoreNulls(int leftBlockIndex, int leftPosition, int rightBlockIndex, int rightPosition)
    {
        // NOTE(review): presumably callers guarantee neither side is null here - confirm against callers.
        for (int hashChannel : hashChannels) {
            Type type = types.get(hashChannel);
            List<Block> channel = channels.get(hashChannel);
            Block leftBlock = channel.get(leftBlockIndex);
            Block rightBlock = channel.get(rightBlockIndex);
            if (!type.equalTo(leftBlock, leftPosition, rightBlock, rightPosition)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if any hash channel is null at ({@code blockIndex}, {@code blockPosition}).
     */
    @Override
    public boolean isPositionNull(int blockIndex, int blockPosition)
    {
        for (int hashChannel : hashChannels) {
            List<Block> channel = channels.get(hashChannel);
            Block block = channel.get(blockIndex);
            if (block.isNull(blockPosition)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compares two stored rows on the sort channel using that channel's {@link Type#compareTo}.
     *
     * @throws UnsupportedOperationException if no sort channel was configured
     */
    @Override
    public int compare(int leftBlockIndex, int leftBlockPosition, int rightBlockIndex, int rightBlockPosition)
    {
        if (!sortChannel.isPresent()) {
            throw new UnsupportedOperationException("compare is not supported: no sort channel was configured");
        }
        int channel = sortChannel.get().getChannel();

        Block leftBlock = channels.get(channel).get(leftBlockIndex);
        Block rightBlock = channels.get(channel).get(rightBlockIndex);

        return types.get(channel).compareTo(leftBlock, leftBlockPosition, rightBlock, rightBlockPosition);
    }
}