/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableList; import com.google.common.collect.MinMaxPriorityQueue; import com.google.common.collect.Ordering; import com.google.common.primitives.Ints; import io.airlift.units.DataSize; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import static com.facebook.presto.operator.GroupByHash.createGroupByHash; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; public class TopNRowNumberOperator implements Operator { public static class TopNRowNumberOperatorFactory implements OperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; private final List<Type> sourceTypes; private final List<Integer> outputChannels; private final List<Integer> partitionChannels; private final List<Type> partitionTypes; private final List<Integer> sortChannels; private final List<SortOrder> sortOrder; private final int maxRowCountPerPartition; private final boolean partial; private final Optional<Integer> hashChannel; private final int expectedPositions; private final List<Type> types; private final List<Type> sortTypes; private final boolean generateRowNumber; private boolean closed; private final JoinCompiler joinCompiler; public TopNRowNumberOperatorFactory( int operatorId, PlanNodeId planNodeId, List<? extends Type> sourceTypes, List<Integer> outputChannels, List<Integer> partitionChannels, List<? extends Type> partitionTypes, List<Integer> sortChannels, List<SortOrder> sortOrder, int maxRowCountPerPartition, boolean partial, Optional<Integer> hashChannel, int expectedPositions, JoinCompiler joinCompiler) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.sourceTypes = ImmutableList.copyOf(sourceTypes); this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null")); this.partitionChannels = ImmutableList.copyOf(requireNonNull(partitionChannels, "partitionChannels is null")); this.partitionTypes = ImmutableList.copyOf(requireNonNull(partitionTypes, "partitionTypes is null")); this.sortChannels = ImmutableList.copyOf(requireNonNull(sortChannels)); this.sortOrder = ImmutableList.copyOf(requireNonNull(sortOrder)); this.hashChannel = requireNonNull(hashChannel, "hashChannel is null"); this.partial = partial; checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); this.maxRowCountPerPartition = maxRowCountPerPartition; checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); this.generateRowNumber = !partial || !partitionChannels.isEmpty(); this.expectedPositions = expectedPositions; this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null"); this.types = toTypes(sourceTypes, outputChannels, generateRowNumber); ImmutableList.Builder<Type> sortTypes = ImmutableList.builder(); for (int channel : sortChannels) { sortTypes.add(types.get(channel)); } this.sortTypes = sortTypes.build(); } @Override public List<Type> getTypes() { return types; } @Override public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, TopNRowNumberOperator.class.getSimpleName()); return new TopNRowNumberOperator( operatorContext, sourceTypes, outputChannels, partitionChannels, partitionTypes, sortChannels, sortOrder, sortTypes, maxRowCountPerPartition, generateRowNumber, hashChannel, expectedPositions, joinCompiler); } @Override public void close() { closed = true; } @Override public OperatorFactory duplicate() { return new TopNRowNumberOperatorFactory(operatorId, planNodeId, sourceTypes, outputChannels, partitionChannels, partitionTypes, sortChannels, sortOrder, maxRowCountPerPartition, partial, hashChannel, expectedPositions, joinCompiler); } } private static final DataSize OVERHEAD_PER_VALUE = new DataSize(100, DataSize.Unit.BYTE); // for estimating in-memory size. This is a completely arbitrary number private final OperatorContext operatorContext; private boolean finishing; private final List<Type> types; private final int[] outputChannels; private final List<Integer> sortChannels; private final List<SortOrder> sortOrders; private final List<Type> sortTypes; private final boolean generateRowNumber; private final int maxRowCountPerPartition; private final Map<Long, PartitionBuilder> partitionRows; private Optional<FlushingPartition> flushingPartition; private final PageBuilder pageBuilder; private final Optional<GroupByHash> groupByHash; public TopNRowNumberOperator( OperatorContext operatorContext, List<? extends Type> sourceTypes, List<Integer> outputChannels, List<Integer> partitionChannels, List<Type> partitionTypes, List<Integer> sortChannels, List<SortOrder> sortOrders, List<Type> sortTypes, int maxRowCountPerPartition, boolean generateRowNumber, Optional<Integer> hashChannel, int expectedPositions, JoinCompiler joinCompiler) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.outputChannels = Ints.toArray(requireNonNull(outputChannels, "outputChannels is null")); this.sortChannels = requireNonNull(sortChannels, "sortChannels is null"); this.sortOrders = requireNonNull(sortOrders, "sortOrders is null"); this.sortTypes = requireNonNull(sortTypes, "sortTypes is null"); checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); this.maxRowCountPerPartition = maxRowCountPerPartition; this.generateRowNumber = generateRowNumber; checkArgument(expectedPositions > 0, "expectedPositions must be > 0"); this.types = toTypes(sourceTypes, outputChannels, generateRowNumber); this.partitionRows = new HashMap<>(); if (partitionChannels.isEmpty()) { this.groupByHash = Optional.empty(); } else { this.groupByHash = Optional.of(createGroupByHash(operatorContext.getSession(), partitionTypes, Ints.toArray(partitionChannels), hashChannel, expectedPositions, joinCompiler)); } this.flushingPartition = Optional.empty(); this.pageBuilder = new PageBuilder(types); } @Override public OperatorContext getOperatorContext() { return operatorContext; } @Override public List<Type> getTypes() { return types; } @Override public void finish() { finishing = true; } @Override public boolean isFinished() { return finishing && isEmpty() && !isFlushing(); } @Override public boolean needsInput() { return !finishing && !isFlushing(); } @Override public void addInput(Page page) { checkState(!finishing, "Operator is already finishing"); requireNonNull(page, "page is null"); checkState(!isFlushing(), "Cannot add input with the operator is flushing data"); processPage(page); } @Override public Page getOutput() { if (finishing && !isFinished()) { return getPage(); } return null; } private void processPage(Page page) { Optional<GroupByIdBlock> partitionIds = Optional.empty(); if (groupByHash.isPresent()) { GroupByHash hash = groupByHash.get(); long groupByHashSize = hash.getEstimatedSize(); partitionIds = Optional.of(hash.getGroupIds(page)); operatorContext.reserveMemory(hash.getEstimatedSize() - groupByHashSize); } long sizeDelta = 0; Block[] blocks = page.getBlocks(); for (int position = 0; position < page.getPositionCount(); position++) { long partitionId = groupByHash.isPresent() ? partitionIds.get().getGroupId(position) : 0; if (!partitionRows.containsKey(partitionId)) { partitionRows.put(partitionId, new PartitionBuilder(sortTypes, sortChannels, sortOrders, maxRowCountPerPartition)); } PartitionBuilder partitionBuilder = partitionRows.get(partitionId); if (partitionBuilder.getRowCount() < maxRowCountPerPartition) { Block[] row = getSingleValueBlocks(page, position); sizeDelta += partitionBuilder.addRow(row); } else if (compare(position, blocks, partitionBuilder.peekLastRow()) < 0) { Block[] row = getSingleValueBlocks(page, position); sizeDelta += partitionBuilder.replaceRow(row); } } if (sizeDelta > 0) { operatorContext.reserveMemory(sizeDelta); } else { operatorContext.freeMemory(-sizeDelta); } } private int compare(int position, Block[] blocks, Block[] currentMax) { for (int i = 0; i < sortChannels.size(); i++) { Type type = sortTypes.get(i); int sortChannel = sortChannels.get(i); SortOrder sortOrder = sortOrders.get(i); Block block = blocks[sortChannel]; Block currentMaxValue = currentMax[sortChannel]; int compare = sortOrder.compareBlockValue(type, block, position, currentMaxValue, 0); if (compare != 0) { return compare; } } return 0; } private Page getPage() { if (!flushingPartition.isPresent()) { flushingPartition = getFlushingPartition(); } pageBuilder.reset(); long sizeDelta = 0; while (!pageBuilder.isFull() && flushingPartition.isPresent()) { FlushingPartition currentFlushingPartition = flushingPartition.get(); while (!pageBuilder.isFull() && currentFlushingPartition.hasNext()) { Block[] next = currentFlushingPartition.next(); sizeDelta += sizeOfRow(next); pageBuilder.declarePosition(); for (int i = 0; i < outputChannels.length; i++) { int channel = outputChannels[i]; Type type = types.get(i); type.appendTo(next[channel], 0, pageBuilder.getBlockBuilder(i)); } if (generateRowNumber) { BIGINT.writeLong(pageBuilder.getBlockBuilder(outputChannels.length), currentFlushingPartition.getRowNumber()); } } if (!currentFlushingPartition.hasNext()) { flushingPartition = getFlushingPartition(); } } if (pageBuilder.isEmpty()) { return null; } Page page = pageBuilder.build(); operatorContext.freeMemory(sizeDelta); return page; } private Optional<FlushingPartition> getFlushingPartition() { int maxPartitionSize = 0; PartitionBuilder chosenPartitionBuilder = null; long chosenPartitionId = -1; for (Map.Entry<Long, PartitionBuilder> entry : partitionRows.entrySet()) { if (entry.getValue().getRowCount() > maxPartitionSize) { chosenPartitionBuilder = entry.getValue(); maxPartitionSize = chosenPartitionBuilder.getRowCount(); chosenPartitionId = entry.getKey(); if (maxPartitionSize == maxRowCountPerPartition) { break; } } } if (chosenPartitionBuilder == null) { return Optional.empty(); } FlushingPartition flushingPartition = new FlushingPartition(chosenPartitionBuilder.build()); partitionRows.remove(chosenPartitionId); return Optional.of(flushingPartition); } public boolean isFlushing() { return flushingPartition.isPresent(); } public boolean isEmpty() { return partitionRows.isEmpty(); } private static Block[] getSingleValueBlocks(Page page, int position) { Block[] blocks = page.getBlocks(); Block[] row = new Block[blocks.length]; for (int i = 0; i < blocks.length; i++) { row[i] = blocks[i].getSingleValueBlock(position); } return row; } private static List<Type> toTypes(List<? extends Type> sourceTypes, List<Integer> outputChannels, boolean generateRowNumber) { ImmutableList.Builder<Type> types = ImmutableList.builder(); for (int channel : outputChannels) { types.add(sourceTypes.get(channel)); } if (generateRowNumber) { types.add(BIGINT); } return types.build(); } private static long sizeOfRow(Block[] row) { long size = OVERHEAD_PER_VALUE.toBytes(); for (Block value : row) { size += value.getRetainedSizeInBytes(); } return size; } private static class PartitionBuilder { private final MinMaxPriorityQueue<Block[]> candidateRows; private final int maxRowCountPerPartition; private PartitionBuilder(List<Type> sortTypes, List<Integer> sortChannels, List<SortOrder> sortOrders, int maxRowCountPerPartition) { this.maxRowCountPerPartition = maxRowCountPerPartition; Ordering<Block[]> comparator = Ordering.from(new RowComparator(sortTypes, sortChannels, sortOrders)); this.candidateRows = MinMaxPriorityQueue.orderedBy(comparator).maximumSize(maxRowCountPerPartition).create(); } private long replaceRow(Block[] row) { checkState(candidateRows.size() == maxRowCountPerPartition); Block[] previousRow = candidateRows.removeLast(); long sizeDelta = addRow(row); return sizeDelta - sizeOfRow(previousRow); } private long addRow(Block[] row) { checkState(candidateRows.size() < maxRowCountPerPartition); long sizeDelta = sizeOfRow(row); candidateRows.add(row); return sizeDelta; } private Iterator<Block[]> build() { ImmutableList.Builder<Block[]> sortedRows = ImmutableList.builder(); while (!candidateRows.isEmpty()) { sortedRows.add(candidateRows.poll()); } return sortedRows.build().iterator(); } private int getRowCount() { return candidateRows.size(); } private Block[] peekLastRow() { return candidateRows.peekLast(); } } private static class FlushingPartition implements Iterator<Block[]> { private final Iterator<Block[]> outputIterator; private int rowNumber; private FlushingPartition(Iterator<Block[]> outputIterator) { this.outputIterator = outputIterator; } @Override public boolean hasNext() { return outputIterator.hasNext(); } @Override public Block[] next() { rowNumber++; return outputIterator.next(); } @Override public void remove() { throw new UnsupportedOperationException(); } public int getRowNumber() { return rowNumber; } } }