/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to you under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.util.ImmutableBitSet; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; /** * Planner rule that recognizes a {@link HiveAggregate} * on top of a {@link HiveProject} and if possible * aggregate through the project or removes the project. * * <p>This is only possible when the grouping expressions and arguments to * the aggregate functions are field references (i.e. not expressions). * * <p>In some cases, this rule has the effect of trimming: the aggregate will * use fewer columns than the project did. */ public class HiveAggregateProjectMergeRule extends RelOptRule { public static final HiveAggregateProjectMergeRule INSTANCE = new HiveAggregateProjectMergeRule(); /** Private constructor. */ private HiveAggregateProjectMergeRule() { super( operand(HiveAggregate.class, operand(HiveProject.class, any()))); } @Override public boolean matches(RelOptRuleCall call) { final Aggregate aggregate = call.rel(0); // Rule cannot be applied if there are GroupingId because it will change the // value as the position will be changed. for (AggregateCall aggCall : aggregate.getAggCallList()) { if (aggCall.getAggregation().equals(HiveGroupingID.INSTANCE)) { return false; } } return super.matches(call); } @Override public void onMatch(RelOptRuleCall call) { final HiveAggregate aggregate = call.rel(0); final HiveProject project = call.rel(1); RelNode x = apply(aggregate, project); if (x != null) { call.transformTo(x); } } public static RelNode apply(HiveAggregate aggregate, HiveProject project) { final List<Integer> newKeys = Lists.newArrayList(); final Map<Integer, Integer> map = new HashMap<>(); for (int key : aggregate.getGroupSet()) { final RexNode rex = project.getProjects().get(key); if (rex instanceof RexInputRef) { final int newKey = ((RexInputRef) rex).getIndex(); newKeys.add(newKey); map.put(key, newKey); } else { // Cannot handle "GROUP BY expression" return null; } } final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map); ImmutableList<ImmutableBitSet> newGroupingSets = null; if (aggregate.indicator) { newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy( ImmutableBitSet.permute(aggregate.getGroupSets(), map)); } final ImmutableList.Builder<AggregateCall> aggCalls = ImmutableList.builder(); for (AggregateCall aggregateCall : aggregate.getAggCallList()) { final ImmutableList.Builder<Integer> newArgs = ImmutableList.builder(); for (int arg : aggregateCall.getArgList()) { final RexNode rex = project.getProjects().get(arg); if (rex instanceof RexInputRef) { newArgs.add(((RexInputRef) rex).getIndex()); } else { // Cannot handle "AGG(expression)" return null; } } final int newFilterArg; if (aggregateCall.filterArg >= 0) { final RexNode rex = project.getProjects().get(aggregateCall.filterArg); if (!(rex instanceof RexInputRef)) { return null; } newFilterArg = ((RexInputRef) rex).getIndex(); } else { newFilterArg = -1; } aggCalls.add(aggregateCall.copy(newArgs.build(), newFilterArg)); } final Aggregate newAggregate = aggregate.copy(aggregate.getTraitSet(), project.getInput(), aggregate.indicator, newGroupSet, newGroupingSets, aggCalls.build()); // Add a project if the group set is not in the same order or // contains duplicates. RelNode rel = newAggregate; if (!newKeys.equals(newGroupSet.asList())) { final List<Integer> posList = Lists.newArrayList(); for (int newKey : newKeys) { posList.add(newGroupSet.indexOf(newKey)); } if (aggregate.indicator) { for (int newKey : newKeys) { posList.add(aggregate.getGroupCount() + newGroupSet.indexOf(newKey)); } } for (int i = newAggregate.getGroupCount() + newAggregate.getIndicatorCount(); i < newAggregate.getRowType().getFieldCount(); i++) { posList.add(i); } rel = HiveRelOptUtil.createProject( HiveRelFactories.HIVE_BUILDER.create(aggregate.getCluster(), null), rel, posList); } return rel; } } // End AggregateProjectMergeRule.java