/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.presto.operator.window.FramedWindowFunction;
import com.facebook.presto.operator.window.WindowPartition;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.primitives.Ints;

import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

import static com.facebook.presto.spi.block.SortOrder.ASC_NULLS_LAST;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkElementIndex;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.concat;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;

/**
 * Operator that buffers input rows in a {@link PagesIndex}, sorts them where required,
 * splits them into partitions and feeds each partition to a {@link WindowPartition}
 * which appends rows to the output {@link PageBuilder}.
 * <p>
 * The output types are the types of the selected {@code outputChannels} followed by the
 * result type of every window function definition, in that order.
 * <p>
 * Input is consumed one "pre-grouped group" at a time: rows are added to the index while
 * the values of the pre-grouped channels stay equal to those of the first buffered row.
 * When a page introduces a different group, the unused remainder of that page is kept in
 * {@code pendingInput} and the buffered group is emitted first.
 */
public class WindowOperator
        implements Operator
{
    /**
     * Factory that validates the window configuration once and creates
     * {@link WindowOperator} instances sharing that configuration.
     */
    public static class WindowOperatorFactory
            implements OperatorFactory
    {
        private final int operatorId;
        private final PlanNodeId planNodeId;
        private final List<Type> sourceTypes;
        private final List<Integer> outputChannels;
        private final List<WindowFunctionDefinition> windowFunctionDefinitions;
        private final List<Integer> partitionChannels;
        private final List<Integer> preGroupedChannels;
        private final List<Integer> sortChannels;
        private final List<SortOrder> sortOrder;
        private final int preSortedChannelPrefix;
        private final int expectedPositions;
        private final List<Type> types;
        private boolean closed;
        private final PagesIndex.Factory pagesIndexFactory;

        /**
         * @param sourceTypes types of the input channels
         * @param outputChannels input channels that are copied to the output
         * @param windowFunctionDefinitions window functions to evaluate; each contributes one output type
         * @param partitionChannels channels that define a window partition
         * @param preGroupedChannels subset of {@code partitionChannels} on which the input is already grouped
         * @param sortChannels channels to order by within a partition
         * @param sortOrder sort order for each entry of {@code sortChannels}; must have the same size
         * @param preSortedChannelPrefix number of leading {@code sortChannels} the input is already sorted on;
         * may only be positive when every partition channel is pre-grouped
         * @param expectedPositions passed to the {@link PagesIndex.Factory} when creating the index
         * @throws NullPointerException if any reference argument is null
         * @throws IllegalArgumentException if the channel/sort arguments are inconsistent
         */
        public WindowOperatorFactory(
                int operatorId,
                PlanNodeId planNodeId,
                List<? extends Type> sourceTypes,
                List<Integer> outputChannels,
                List<WindowFunctionDefinition> windowFunctionDefinitions,
                List<Integer> partitionChannels,
                List<Integer> preGroupedChannels,
                List<Integer> sortChannels,
                List<SortOrder> sortOrder,
                int preSortedChannelPrefix,
                int expectedPositions,
                PagesIndex.Factory pagesIndexFactory)
        {
            requireNonNull(sourceTypes, "sourceTypes is null");
            requireNonNull(planNodeId, "planNodeId is null");
            requireNonNull(outputChannels, "outputChannels is null");
            requireNonNull(windowFunctionDefinitions, "windowFunctionDefinitions is null");
            requireNonNull(partitionChannels, "partitionChannels is null");
            requireNonNull(preGroupedChannels, "preGroupedChannels is null");
            checkArgument(partitionChannels.containsAll(preGroupedChannels), "preGroupedChannels must be a subset of partitionChannels");
            requireNonNull(sortChannels, "sortChannels is null");
            requireNonNull(sortOrder, "sortOrder is null");
            requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
            checkArgument(sortChannels.size() == sortOrder.size(), "Must have same number of sort channels as sort orders");
            checkArgument(preSortedChannelPrefix <= sortChannels.size(), "Cannot have more pre-sorted channels than specified sorted channels");
            checkArgument(preSortedChannelPrefix == 0 || ImmutableSet.copyOf(preGroupedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedChannelPrefix can only be greater than zero if all partition channels are pre-grouped");

            this.pagesIndexFactory = pagesIndexFactory;
            this.operatorId = operatorId;
            this.planNodeId = planNodeId;
            this.sourceTypes = ImmutableList.copyOf(sourceTypes);
            this.outputChannels = ImmutableList.copyOf(outputChannels);
            this.windowFunctionDefinitions = ImmutableList.copyOf(windowFunctionDefinitions);
            this.partitionChannels = ImmutableList.copyOf(partitionChannels);
            this.preGroupedChannels = ImmutableList.copyOf(preGroupedChannels);
            this.sortChannels = ImmutableList.copyOf(sortChannels);
            this.sortOrder = ImmutableList.copyOf(sortOrder);
            this.preSortedChannelPrefix = preSortedChannelPrefix;
            this.expectedPositions = expectedPositions;

            // Output layout: selected source channels first, then one column per window function
            this.types = Stream.concat(
                    outputChannels.stream()
                            .map(sourceTypes::get),
                    windowFunctionDefinitions.stream()
                            .map(WindowFunctionDefinition::getType))
                    .collect(toImmutableList());
        }

        @Override
        public List<Type> getTypes()
        {
            return types;
        }

        /**
         * Creates a new operator registered under this factory's operator id and plan node id.
         *
         * @throws IllegalStateException if {@link #close()} has already been called
         */
        @Override
        public Operator createOperator(DriverContext driverContext)
        {
            checkState(!closed, "Factory is already closed");

            OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, WindowOperator.class.getSimpleName());
            return new WindowOperator(
                    operatorContext,
                    sourceTypes,
                    outputChannels,
                    windowFunctionDefinitions,
                    partitionChannels,
                    preGroupedChannels,
                    sortChannels,
                    sortOrder,
                    preSortedChannelPrefix,
                    expectedPositions,
                    pagesIndexFactory);
        }

        /**
         * Marks the factory closed; subsequent {@link #createOperator} calls fail.
         */
        @Override
        public void close()
        {
            closed = true;
        }

        /**
         * Returns a new, open factory with the same configuration as this one.
         */
        @Override
        public OperatorFactory duplicate()
        {
            return new WindowOperatorFactory(
                    operatorId,
                    planNodeId,
                    sourceTypes,
                    outputChannels,
                    windowFunctionDefinitions,
                    partitionChannels,
                    preGroupedChannels,
                    sortChannels,
                    sortOrder,
                    preSortedChannelPrefix,
                    expectedPositions,
                    pagesIndexFactory);
        }
    }

    /**
     * Lifecycle of the operator.
     * <ul>
     * <li>{@code NEEDS_INPUT}: accepting pages through {@link #addInput}</li>
     * <li>{@code HAS_OUTPUT}: a full group is buffered and is being drained through {@link #getOutput}</li>
     * <li>{@code FINISHING}: {@link #finish} was called; remaining buffered data is being drained</li>
     * <li>{@code FINISHED}: all output has been produced</li>
     * </ul>
     */
    private enum State
    {
        NEEDS_INPUT,
        HAS_OUTPUT,
        FINISHING,
        FINISHED
    }

    private final OperatorContext operatorContext;
    private final int[] outputChannels;
    private final List<FramedWindowFunction> windowFunctions;
    // Channels and orderings actually used when sorting the buffered rows (see constructor)
    private final List<Integer> orderChannels;
    private final List<SortOrder> ordering;
    private final List<Type> types;

    private final int[] preGroupedChannels;

    private final PagesHashStrategy preGroupedPartitionHashStrategy;
    private final PagesHashStrategy unGroupedPartitionHashStrategy;
    private final PagesHashStrategy preSortedPartitionHashStrategy;
    private final PagesHashStrategy peerGroupHashStrategy;

    private final PagesIndex pagesIndex;

    private final PageBuilder pageBuilder;

    private State state = State.NEEDS_INPUT;

    // Partition currently being emitted, or null when none has been started in the current index
    private WindowPartition partition;

    // Remainder of an input page that belongs to a later pre-grouped group, or null
    private Page pendingInput;

    /**
     * @param sourceTypes types of the input channels
     * @param outputChannels input channels that are copied to the output
     * @param windowFunctionDefinitions window functions to evaluate
     * @param partitionChannels channels that define a window partition
     * @param preGroupedChannels subset of {@code partitionChannels} on which the input is already grouped
     * @param sortChannels channels to order by within a partition
     * @param sortOrder sort order for each entry of {@code sortChannels}; must have the same size
     * @param preSortedChannelPrefix number of leading {@code sortChannels} the input is already sorted on;
     * may only be positive when every partition channel is pre-grouped
     * @param expectedPositions passed to the {@link PagesIndex.Factory} when creating the index
     * @throws NullPointerException if any reference argument is null
     * @throws IllegalArgumentException if the channel/sort arguments are inconsistent
     */
    public WindowOperator(
            OperatorContext operatorContext,
            List<Type> sourceTypes,
            List<Integer> outputChannels,
            List<WindowFunctionDefinition> windowFunctionDefinitions,
            List<Integer> partitionChannels,
            List<Integer> preGroupedChannels,
            List<Integer> sortChannels,
            List<SortOrder> sortOrder,
            int preSortedChannelPrefix,
            int expectedPositions,
            PagesIndex.Factory pagesIndexFactory)
    {
        requireNonNull(operatorContext, "operatorContext is null");
        // Same check the factory performs; gives a descriptive message instead of a bare NPE below
        requireNonNull(sourceTypes, "sourceTypes is null");
        requireNonNull(outputChannels, "outputChannels is null");
        requireNonNull(windowFunctionDefinitions, "windowFunctionDefinitions is null");
        requireNonNull(partitionChannels, "partitionChannels is null");
        requireNonNull(preGroupedChannels, "preGroupedChannels is null");
        checkArgument(partitionChannels.containsAll(preGroupedChannels), "preGroupedChannels must be a subset of partitionChannels");
        requireNonNull(sortChannels, "sortChannels is null");
        requireNonNull(sortOrder, "sortOrder is null");
        requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
        checkArgument(sortChannels.size() == sortOrder.size(), "Must have same number of sort channels as sort orders");
        checkArgument(preSortedChannelPrefix <= sortChannels.size(), "Cannot have more pre-sorted channels than specified sorted channels");
        checkArgument(preSortedChannelPrefix == 0 || ImmutableSet.copyOf(preGroupedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedChannelPrefix can only be greater than zero if all partition channels are pre-grouped");

        this.operatorContext = operatorContext;
        this.outputChannels = Ints.toArray(outputChannels);
        this.windowFunctions = windowFunctionDefinitions.stream()
                .map(functionDefinition -> new FramedWindowFunction(functionDefinition.createWindowFunction(), functionDefinition.getFrameInfo()))
                .collect(toImmutableList());

        // Output layout: selected source channels first, then one column per window function
        this.types = Stream.concat(
                outputChannels.stream()
                        .map(sourceTypes::get),
                windowFunctionDefinitions.stream()
                        .map(WindowFunctionDefinition::getType))
                .collect(toImmutableList());

        this.pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions);
        this.preGroupedChannels = Ints.toArray(preGroupedChannels);
        this.preGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preGroupedChannels, Optional.empty());
        List<Integer> unGroupedPartitionChannels = partitionChannels.stream()
                .filter(channel -> !preGroupedChannels.contains(channel))
                .collect(toImmutableList());
        this.unGroupedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(unGroupedPartitionChannels, Optional.empty());
        List<Integer> preSortedChannels = sortChannels.stream()
                .limit(preSortedChannelPrefix)
                .collect(toImmutableList());
        this.preSortedPartitionHashStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, Optional.empty());
        this.peerGroupHashStrategy = pagesIndex.createPagesHashStrategy(sortChannels, Optional.empty());

        this.pageBuilder = new PageBuilder(this.types);

        if (preSortedChannelPrefix > 0) {
            // This already implies that set(preGroupedChannels) == set(partitionChannels) (enforced with checkArgument)
            this.orderChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedChannelPrefix));
            this.ordering = ImmutableList.copyOf(Iterables.skip(sortOrder, preSortedChannelPrefix));
        }
        else {
            // Otherwise, we need to sort by the unGroupedPartitionChannels and all original sort channels
            this.orderChannels = ImmutableList.copyOf(concat(unGroupedPartitionChannels, sortChannels));
            this.ordering = ImmutableList.copyOf(concat(nCopies(unGroupedPartitionChannels.size(), ASC_NULLS_LAST), sortOrder));
        }
    }

    @Override
    public OperatorContext getOperatorContext()
    {
        return operatorContext;
    }

    @Override
    public List<Type> getTypes()
    {
        return types;
    }

    /**
     * Signals that no more input will arrive. Calling it again once finishing or finished is a no-op.
     */
    @Override
    public void finish()
    {
        if (state == State.FINISHING || state == State.FINISHED) {
            return;
        }
        if (state == State.NEEDS_INPUT) {
            // Since was waiting for more input, prepare what we have for output since we will not be getting any more input
            sortPagesIndexIfNecessary();
        }
        state = State.FINISHING;
    }

    @Override
    public boolean isFinished()
    {
        return state == State.FINISHED;
    }

    @Override
    public boolean needsInput()
    {
        return state == State.NEEDS_INPUT;
    }

    /**
     * Buffers the given page. Empty pages are ignored. If the page completes a pre-grouped
     * group, the operator switches to {@code HAS_OUTPUT}. The memory reservation is updated
     * from the index's estimated size afterwards.
     *
     * @throws IllegalStateException if the operator is not in {@code NEEDS_INPUT} or still has pending input
     * @throws NullPointerException if {@code page} is null
     */
    @Override
    public void addInput(Page page)
    {
        checkState(state == State.NEEDS_INPUT, "Operator can not take input at this time");
        requireNonNull(page, "page is null");
        checkState(pendingInput == null, "Operator already has pending input");

        if (page.getPositionCount() == 0) {
            return;
        }

        pendingInput = page;
        if (processPendingInput()) {
            state = State.HAS_OUTPUT;
        }
        operatorContext.setMemoryReservation(pagesIndex.getEstimatedSize().toBytes());
    }

    /**
     * @return true if a full group has been buffered after processing the pendingInput, false otherwise
     */
    private boolean processPendingInput()
    {
        checkState(pendingInput != null);
        pendingInput = updatePagesIndex(pendingInput);

        // If we have unused input or are finishing, then we have buffered a full group
        if (pendingInput != null || state == State.FINISHING) {
            sortPagesIndexIfNecessary();
            return true;
        }
        else {
            return false;
        }
    }

    /**
     * @return the unused section of the page, or null if fully applied.
     * pagesIndex guaranteed to have at least one row after this method returns
     */
    private Page updatePagesIndex(Page page)
    {
        checkArgument(page.getPositionCount() > 0);

        // TODO: Fix pagesHashStrategy to allow specifying channels for comparison, it currently requires us to rearrange the right side blocks in consecutive channel order
        Page preGroupedPage = rearrangePage(page, preGroupedChannels);
        if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(preGroupedPartitionHashStrategy, 0, 0, preGroupedPage)) {
            // Find the position where the pre-grouped columns change
            int groupEnd = findGroupEnd(preGroupedPage, preGroupedPartitionHashStrategy, 0);

            // Add the section of the page that contains values for the current group
            pagesIndex.addPage(page.getRegion(0, groupEnd));

            if (page.getPositionCount() - groupEnd > 0) {
                // Save the remaining page, which may contain multiple partitions
                return page.getRegion(groupEnd, page.getPositionCount() - groupEnd);
            }
            else {
                // Page fully consumed
                return null;
            }
        }
        else {
            // We had previous results buffered, but the new page starts with new group values
            return page;
        }
    }

    /**
     * Builds a page that contains only the given channels of {@code page}, in the given order.
     * The blocks are reused, not copied.
     */
    private static Page rearrangePage(Page page, int[] channels)
    {
        Block[] newBlocks = new Block[channels.length];
        for (int i = 0; i < channels.length; i++) {
            newBlocks[i] = page.getBlock(channels[i]);
        }
        return new Page(page.getPositionCount(), newBlocks);
    }

    /**
     * Returns the next output page, or null when the operator is waiting for input, is
     * finished, or has nothing to emit yet. The memory reservation is updated from the
     * index's estimated size after extracting output.
     */
    @Override
    public Page getOutput()
    {
        if (state == State.NEEDS_INPUT || state == State.FINISHED) {
            return null;
        }

        Page page = extractOutput();
        operatorContext.setMemoryReservation(pagesIndex.getEstimatedSize().toBytes());
        return page;
    }

    /**
     * Drains buffered partitions into the page builder until it is full or the buffered data
     * (including any pending input that completes another group) is exhausted.
     *
     * @return a built page, or null when no page is produced by this call
     */
    private Page extractOutput()
    {
        // INVARIANT: pagesIndex contains the full grouped & sorted data for one or more partitions

        // Iterate through the positions sequentially until we have one full page
        while (!pageBuilder.isFull()) {
            if (partition == null || !partition.hasNext()) {
                int partitionStart = partition == null ? 0 : partition.getPartitionEnd();

                if (partitionStart >= pagesIndex.getPositionCount()) {
                    // Finished all of the partitions in the current pagesIndex
                    partition = null;
                    pagesIndex.clear();

                    // Try to extract more partitions from the pendingInput
                    if (pendingInput != null && processPendingInput()) {
                        partitionStart = 0;
                    }
                    else if (state == State.FINISHING) {
                        state = State.FINISHED;
                        // Output the remaining page if we have anything buffered
                        if (!pageBuilder.isEmpty()) {
                            Page page = pageBuilder.build();
                            pageBuilder.reset();
                            return page;
                        }
                        return null;
                    }
                    else {
                        state = State.NEEDS_INPUT;
                        return null;
                    }
                }

                int partitionEnd = findGroupEnd(pagesIndex, unGroupedPartitionHashStrategy, partitionStart);
                partition = new WindowPartition(pagesIndex, partitionStart, partitionEnd, outputChannels, windowFunctions, peerGroupHashStrategy);
            }

            partition.processNextRow(pageBuilder);
        }

        Page page = pageBuilder.build();
        pageBuilder.reset();
        return page;
    }

    /**
     * Sorts the buffered rows by {@code orderChannels}. Each run of rows that is equal on the
     * pre-sorted channels is sorted separately. Nothing is done when the index holds at most
     * one row or there are no order channels.
     */
    private void sortPagesIndexIfNecessary()
    {
        if (pagesIndex.getPositionCount() > 1 && !orderChannels.isEmpty()) {
            int startPosition = 0;
            while (startPosition < pagesIndex.getPositionCount()) {
                int endPosition = findGroupEnd(pagesIndex, preSortedPartitionHashStrategy, startPosition);
                pagesIndex.sort(orderChannels, ordering, startPosition, endPosition);
                startPosition = endPosition;
            }
        }
    }

    /**
     * Finds the exclusive end of the run of equal rows that starts at {@code startPosition}.
     * Assumes input grouped on relevant pagesHashStrategy columns.
     *
     * @throws IllegalArgumentException if the page is empty
     * @throws IndexOutOfBoundsException if {@code startPosition} is not an existing position
     */
    private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, int startPosition)
    {
        checkArgument(page.getPositionCount() > 0, "Must have at least one position");
        // startPosition is read as a row below, so it must be strictly less than the position count
        checkElementIndex(startPosition, page.getPositionCount(), "startPosition out of bounds");

        // Short circuit if the whole page has the same value
        if (pagesHashStrategy.rowEqualsRow(startPosition, page, page.getPositionCount() - 1, page)) {
            return page.getPositionCount();
        }

        // TODO: do position binary search
        int endPosition = startPosition + 1;
        while (endPosition < page.getPositionCount() &&
                pagesHashStrategy.rowEqualsRow(endPosition - 1, page, endPosition, page)) {
            endPosition++;
        }
        return endPosition;
    }

    /**
     * Finds the exclusive end of the run of equal rows that starts at {@code startPosition}.
     * Assumes input grouped on relevant pagesHashStrategy columns.
     *
     * @throws IllegalArgumentException if the index is empty
     * @throws IndexOutOfBoundsException if {@code startPosition} is not an existing position
     */
    private static int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHashStrategy, int startPosition)
    {
        checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position");
        // startPosition is read as a row below, so it must be strictly less than the position count
        checkElementIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds");

        // Short circuit if the whole page has the same value
        if (pagesIndex.positionEqualsPosition(pagesHashStrategy, startPosition, pagesIndex.getPositionCount() - 1)) {
            return pagesIndex.getPositionCount();
        }

        // TODO: do position binary search
        int endPosition = startPosition + 1;
        while ((endPosition < pagesIndex.getPositionCount()) &&
                pagesIndex.positionEqualsPosition(pagesHashStrategy, endPosition - 1, endPosition)) {
            endPosition++;
        }
        return endPosition;
    }
}