/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; import com.facebook.presto.spi.block.RunLengthEncodedBlock; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; import java.util.Map; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; public class GroupIdOperator implements Operator { public static class GroupIdOperatorFactory implements OperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; private final List<Type> outputTypes; private final List<Map<Integer, Integer>> groupingSetMappings; private boolean closed; public GroupIdOperatorFactory( int operatorId, PlanNodeId planNodeId, List<? extends Type> outputTypes, List<Map<Integer, Integer>> groupingSetMappings) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); this.outputTypes = ImmutableList.copyOf(requireNonNull(outputTypes)); this.groupingSetMappings = ImmutableList.copyOf(requireNonNull(groupingSetMappings)); } @Override public List<Type> getTypes() { return outputTypes; } @Override public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, GroupIdOperator.class.getSimpleName()); // create an int array for fast lookup of input columns for every grouping set int[][] groupingSetInputs = new int[groupingSetMappings.size()][outputTypes.size() - 1]; for (int i = 0; i < groupingSetMappings.size(); i++) { // -1 means the output column is null Arrays.fill(groupingSetInputs[i], -1); // anything else is an input column to copy for (int outputChannel : groupingSetMappings.get(i).keySet()) { groupingSetInputs[i][outputChannel] = groupingSetMappings.get(i).get(outputChannel); } } // it's easier to create null blocks for every output column even though we only null out some grouping column outputs Block[] nullBlocks = new Block[outputTypes.size()]; for (int i = 0; i < outputTypes.size(); i++) { nullBlocks[i] = outputTypes.get(i).createBlockBuilder(new BlockBuilderStatus(), 1) .appendNull() .build(); } // create groupid blocks for every group Block[] groupIdBlocks = new Block[groupingSetMappings.size()]; for (int i = 0; i < groupingSetMappings.size(); i++) { BlockBuilder builder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), 1); BIGINT.writeLong(builder, i); groupIdBlocks[i] = builder.build(); } return new GroupIdOperator(operatorContext, outputTypes, groupingSetInputs, nullBlocks, groupIdBlocks); } @Override public void close() { closed = true; } @Override public OperatorFactory duplicate() { return new GroupIdOperatorFactory(operatorId, planNodeId, outputTypes, groupingSetMappings); } } private final OperatorContext operatorContext; private final List<Type> types; private final int[][] groupingSetInputs; private final Block[] nullBlocks; private final Block[] groupIdBlocks; private Page currentPage = null; private int currentGroupingSet = 0; private boolean finishing; public GroupIdOperator( OperatorContext operatorContext, List<Type> types, int[][] groupingSetInputs, Block[] nullBlocks, Block[] groupIdBlocks) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); this.groupingSetInputs = requireNonNull(groupingSetInputs, "groupingSetInputs is null"); this.nullBlocks = requireNonNull(nullBlocks, "nullBlocks is null"); this.groupIdBlocks = requireNonNull(groupIdBlocks, "groupIdBlocks is null"); } @Override public OperatorContext getOperatorContext() { return operatorContext; } @Override public List<Type> getTypes() { return types; } @Override public void finish() { finishing = true; } @Override public boolean isFinished() { return finishing && currentPage == null; } @Override public boolean needsInput() { return !finishing && currentPage == null; } @Override public void addInput(Page page) { checkState(!finishing, "Operator is already finishing"); checkState(currentPage == null, "currentPage must be null to add a new page"); currentPage = requireNonNull(page, "page is null"); } @Override public Page getOutput() { if (currentPage == null) { return null; } return generateNextPage(); } private Page generateNextPage() { // generate 'n' pages for every input page, where n is the number of grouping sets Block[] outputBlocks = new Block[types.size()]; for (int i = 0; i < groupingSetInputs[currentGroupingSet].length; i++) { if (groupingSetInputs[currentGroupingSet][i] == -1) { outputBlocks[i] = new RunLengthEncodedBlock(nullBlocks[i], currentPage.getPositionCount()); } else { outputBlocks[i] = currentPage.getBlock(groupingSetInputs[currentGroupingSet][i]); } } outputBlocks[outputBlocks.length - 1] = new RunLengthEncodedBlock(groupIdBlocks[currentGroupingSet], currentPage.getPositionCount()); currentGroupingSet = (currentGroupingSet + 1) % groupingSetInputs.length; Page outputPage = new Page(currentPage.getPositionCount(), outputBlocks); if (currentGroupingSet == 0) { currentPage = null; } return outputPage; } }