/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.math.IntMath;
/**
 * Planner rule that expands distinct aggregates
 * (such as {@code COUNT(DISTINCT x)}) from a
 * {@link org.apache.calcite.rel.core.Aggregate}.
 *
 * <p>The rule only acts on aggregates that contain at least one
 * {@code COUNT(DISTINCT ...)} call, and then handles two shapes:
 *
 * <ul>
 * <li>If there are two or more {@code COUNT(DISTINCT ...)} calls, every
 * aggregate call is a {@code COUNT(DISTINCT ...)} and there is no GROUP BY
 * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)}), the aggregate is
 * rewritten into a grouping-sets aggregate followed by plain {@code COUNT}s
 * over {@code CASE} expressions that test the grouping id.</li>
 * <li>Otherwise, if all aggregate calls are distinct and have the same
 * argument list (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have
 * the argument {@code x}), one extra
 * {@link org.apache.calcite.rel.core.Aggregate} that produces the distinct
 * rows is sufficient. This rewrite is skipped when an argument originates
 * from a partitioning column.</li>
 * </ul>
 *
 * <p>Any other aggregate is left untouched by this rule.
 */
// Stripped down version of org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule
// This is adapted for Hive, but should eventually be deleted from Hive and make use of above.
public final class HiveExpandDistinctAggregatesRule extends RelOptRule {
  //~ Static fields/initializers ---------------------------------------------

  /** The default instance of the rule; operates only on logical expressions. */
  public static final HiveExpandDistinctAggregatesRule INSTANCE =
      new HiveExpandDistinctAggregatesRule(HiveAggregate.class,
          HiveRelFactories.HIVE_PROJECT_FACTORY);

  protected static final Logger LOG = LoggerFactory.getLogger(HiveExpandDistinctAggregatesRule.class);

  //~ Instance fields --------------------------------------------------------

  /**
   * Factory used to build the projection underneath the "select distinct"
   * aggregate. Kept per instance so that creating another rule with a
   * different factory cannot change the behavior of {@link #INSTANCE}.
   */
  private final RelFactories.ProjectFactory projFactory;

  // Set by onMatch() before the count-distinct rewrite and read by the helper
  // methods of that rewrite.
  // NOTE(review): this is mutable state on a rule that is shared through
  // INSTANCE; concurrent onMatch() calls would race on these two fields.
  // Confirm the planner's threading model before relying on it.
  RelOptCluster cluster = null;
  RexBuilder rexBuilder = null;

  //~ Constructors -----------------------------------------------------------

  /**
   * Creates the rule.
   *
   * @param clazz          class of aggregate the rule matches
   * @param projectFactory factory used to create projections
   */
  public HiveExpandDistinctAggregatesRule(
      Class<? extends Aggregate> clazz, RelFactories.ProjectFactory projectFactory) {
    super(operand(clazz, any()));
    this.projFactory = projectFactory;
  }

  //~ Methods ----------------------------------------------------------------

  /**
   * Rewrites the matched aggregate if it has one of the two supported shapes
   * (see the class comment); otherwise returns without transforming.
   */
  @Override
  public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    int numCountDistinct = getNumCountDistinctCall(aggregate);
    if (numCountDistinct == 0) {
      return;
    }

    // Find all of the agg expressions. We use a List (for all count(distinct))
    // as well as a Set (for all others) to ensure determinism.
    int nonDistinctCount = 0;
    List<List<Integer>> argListList = new ArrayList<>();
    Set<List<Integer>> argListSets = new LinkedHashSet<>();
    Set<Integer> positions = new HashSet<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      if (!aggCall.isDistinct()) {
        ++nonDistinctCount;
        continue;
      }
      List<Integer> argList = new ArrayList<>(aggCall.getArgList());
      positions.addAll(argList);
      // Aggr checks for sorted argList.
      argListList.add(argList);
      argListSets.add(argList);
    }
    Util.permAssert(argListSets.size() > 0, "containsDistinctCall lied");

    if (numCountDistinct > 1 && numCountDistinct == aggregate.getAggCallList().size()
        && aggregate.getGroupSet().isEmpty()) {
      LOG.debug("Trigger countDistinct rewrite. numCountDistinct is {}", numCountDistinct);
      // now positions contains all the distinct positions, i.e., $5, $4, $6
      // we need to first sort them as group by set
      // and then get their position later, i.e., $4->1, $5->2, $6->3
      cluster = aggregate.getCluster();
      rexBuilder = cluster.getRexBuilder();
      List<Integer> sourceOfForCountDistinct = new ArrayList<>(positions);
      Collections.sort(sourceOfForCountDistinct);
      RelNode converted;
      try {
        converted = convert(aggregate, argListList, sourceOfForCountDistinct);
      } catch (CalciteSemanticException e) {
        LOG.debug(e.toString());
        throw new RuntimeException(e);
      }
      call.transformTo(converted);
      return;
    }

    // If all of the agg expressions are distinct and have the same
    // arguments then we can use a more efficient form.
    if ((nonDistinctCount == 0) && (argListSets.size() == 1)) {
      final List<Integer> sharedArgList = argListSets.iterator().next();
      for (Integer arg : sharedArgList) {
        Set<RelColumnOrigin> colOrigs = RelMetadataQuery.instance().getColumnOrigins(aggregate, arg);
        if (null != colOrigs) {
          for (RelColumnOrigin colOrig : colOrigs) {
            RelOptHiveTable hiveTbl = (RelOptHiveTable) colOrig.getOriginTable();
            if (hiveTbl.getPartColInfoMap().containsKey(colOrig.getOriginColumnOrdinal())) {
              // Encountered partitioning column, this will be better handled by MetadataOnly optimizer.
              return;
            }
          }
        }
      }
      RelNode converted = convertMonopole(aggregate, sharedArgList);
      call.transformTo(converted);
      return;
    }
  }

  /**
   * Converts an aggregate relational expression that contains only
   * count(distinct) to grouping sets with count. For example
   *
   * <pre>
   * select count(distinct department_id), count(distinct gender),
   *        count(distinct education_level)
   * from employee;
   * </pre>
   *
   * can be transformed to
   *
   * <pre>
   * select
   *   count(case when i=3 and department_id is not null then 1 else null end) as c0,
   *   count(case when i=5 and gender is not null then 1 else null end) as c1,
   *   count(case when i=6 and education_level is not null then 1 else null end) as c2
   * from (select
   *   grouping__id as i, department_id, gender, education_level from employee
   *   group by department_id, gender, education_level grouping sets
   *   (department_id, gender, education_level)) subq;
   * </pre>
   *
   * The literals compared against the grouping id are the values produced by
   * {@link #getGroupingIdValue} for three grouping columns.
   *
   * @param aggregate                the original aggregate
   * @param argList                  argument lists of the original distinct calls, in call order
   * @param sourceOfForCountDistinct sorted input ordinals used by the distinct calls
   * @return the rewritten relational expression
   * @throws CalciteSemanticException if the count aggregate calls cannot be created
   */
  private RelNode convert(Aggregate aggregate, List<List<Integer>> argList,
      List<Integer> sourceOfForCountDistinct) throws CalciteSemanticException {
    // we use this map to map the position of argList to the position of grouping set
    Map<Integer, Integer> map = new HashMap<>();
    List<List<Integer>> cleanArgList = new ArrayList<>();
    final Aggregate groupingSets =
        createGroupingSets(aggregate, argList, cleanArgList, map, sourceOfForCountDistinct);
    return createCount(groupingSets, argList, cleanArgList, map, sourceOfForCountDistinct);
  }

  /**
   * Computes the grouping id literal that identifies one grouping set.
   *
   * <p>Starts from {@code groupCount} one-bits and clears the bit of every
   * column in {@code list}; the first grouping column corresponds to the most
   * significant bit.
   *
   * @param list                     input ordinals that make up the grouping set
   * @param sourceOfForCountDistinct sorted input ordinals of all grouping columns
   * @param groupCount               number of grouping columns
   * @return the grouping id value for the grouping set
   */
  private int getGroupingIdValue(List<Integer> list, List<Integer> sourceOfForCountDistinct,
      int groupCount) {
    int ind = IntMath.pow(2, groupCount) - 1;
    for (int i : list) {
      ind &= ~(1 << groupCount - sourceOfForCountDistinct.indexOf(i) - 1);
    }
    return ind;
  }

  /**
   * Builds the {@code COUNT(CASE ...)} aggregate on top of the grouping-sets
   * aggregate, plus a final projection when duplicate argument lists have to
   * be mapped back to their original positions.
   *
   * @param aggr                     the grouping-sets aggregate created by
   *                                 {@link #createGroupingSets}
   * @param argList                  the original argList in aggregate
   * @param cleanArgList             the new argList without duplicates
   * @param map                      the mapping from the original argList to the new argList
   * @param sourceOfForCountDistinct the sorted positions of groupset
   * @return the count aggregate, wrapped in a project if {@code map} is not empty
   * @throws CalciteSemanticException if a count aggregate call cannot be created
   */
  private RelNode createCount(Aggregate aggr, List<List<Integer>> argList,
      List<List<Integer>> cleanArgList, Map<Integer, Integer> map,
      List<Integer> sourceOfForCountDistinct) throws CalciteSemanticException {
    // References to the output fields of the grouping-sets aggregate.
    List<RexNode> originalInputRefs = Lists.transform(aggr.getRowType().getFieldList(),
        new Function<RelDataTypeField, RexNode>() {
          @Override
          public RexNode apply(RelDataTypeField input) {
            return new RexInputRef(input.getIndex(), input.getType());
          }
        });
    final List<RexNode> gbChildProjLst = Lists.newArrayList();
    // for singular arg, count should not include null
    // e.g., count(case when i=3 and department_id is not null then 1 else null end) as c0,
    // for non-singular args, count can include null, i.e. (,) is counted as 1
    for (List<Integer> list : cleanArgList) {
      // The grouping id is the only aggregate call of aggr, hence its last field.
      RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          originalInputRefs.get(originalInputRefs.size() - 1),
          rexBuilder.makeExactLiteral(new BigDecimal(
              getGroupingIdValue(list, sourceOfForCountDistinct, aggr.getGroupCount()))));
      if (list.size() == 1) {
        int pos = list.get(0);
        // pos is an ordinal in the input of the ORIGINAL aggregate, while
        // originalInputRefs indexes the output of the grouping-sets aggregate.
        // Translate it through the sorted group columns, exactly as
        // getGroupingIdValue() does, so the null check hits the right column.
        RexNode notNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
            originalInputRefs.get(sourceOfForCountDistinct.indexOf(pos)));
        condition = rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, notNull);
      }
      RexNode when = rexBuilder.makeCall(SqlStdOperatorTable.CASE, condition,
          rexBuilder.makeExactLiteral(BigDecimal.ONE), rexBuilder.constantNull());
      gbChildProjLst.add(when);
    }

    // create the project before GB
    RelNode gbInputRel = HiveProject.create(aggr, gbChildProjLst, null);

    // create the aggregate
    List<AggregateCall> aggregateCalls = Lists.newArrayList();
    RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo,
        cluster.getTypeFactory());
    for (int i = 0; i < cleanArgList.size(); i++) {
      AggregateCall aggregateCall = HiveCalciteUtil.createSingleArgAggCall("count", cluster,
          TypeInfoFactory.longTypeInfo, i, aggFnRetType);
      aggregateCalls.add(aggregateCall);
    }
    Aggregate aggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
        gbInputRel, false, ImmutableBitSet.of(), null, aggregateCalls);

    // create the project after GB. For those repeated values, e.g., select
    // count(distinct x, y), count(distinct y, x), we find the correct mapping.
    if (map.isEmpty()) {
      return aggregate;
    } else {
      List<RexNode> originalAggrRefs = Lists.transform(aggregate.getRowType().getFieldList(),
          new Function<RelDataTypeField, RexNode>() {
            @Override
            public RexNode apply(RelDataTypeField input) {
              return new RexInputRef(input.getIndex(), input.getType());
            }
          });
      final List<RexNode> projLst = Lists.newArrayList();
      // index walks the de-duplicated counts in order; duplicates reuse the
      // count recorded for them in map.
      int index = 0;
      for (int i = 0; i < argList.size(); i++) {
        if (map.containsKey(i)) {
          projLst.add(originalAggrRefs.get(map.get(i)));
        } else {
          projLst.add(originalAggrRefs.get(index++));
        }
      }
      return HiveProject.create(aggregate, projLst, null);
    }
  }

  /**
   * Creates the grouping-sets aggregate over the input of the original
   * aggregate: one grouping set per distinct argument set, plus a grouping id
   * column as the only aggregate call.
   *
   * @param aggregate                the original aggregate
   * @param argList                  the original argList in aggregate
   * @param cleanArgList             out parameter, receives the argLists without duplicates
   * @param map                      out parameter, receives the mapping from the index of a
   *                                 duplicate in argList to the index of its first occurrence
   *                                 in cleanArgList
   * @param sourceOfForCountDistinct the sorted positions of groupset
   * @return the grouping-sets aggregate
   */
  private Aggregate createGroupingSets(Aggregate aggregate, List<List<Integer>> argList,
      List<List<Integer>> cleanArgList, Map<Integer, Integer> map,
      List<Integer> sourceOfForCountDistinct) {
    final ImmutableBitSet groupSet = ImmutableBitSet.of(sourceOfForCountDistinct);
    final List<ImmutableBitSet> origGroupSets = new ArrayList<>();

    for (int i = 0; i < argList.size(); i++) {
      List<Integer> list = argList.get(i);
      ImmutableBitSet bitSet = ImmutableBitSet.of(list);
      int prev = origGroupSets.indexOf(bitSet);
      if (prev == -1) {
        origGroupSets.add(bitSet);
        cleanArgList.add(list);
      } else {
        map.put(i, prev);
      }
    }
    // Calcite expects the grouping sets sorted and without duplicates
    Collections.sort(origGroupSets, ImmutableBitSet.COMPARATOR);

    List<AggregateCall> aggregateCalls = new ArrayList<>();
    // Create GroupingID column
    AggregateCall aggCall = AggregateCall.create(HiveGroupingID.INSTANCE, false,
        new ImmutableList.Builder<Integer>().build(), -1, this.cluster.getTypeFactory()
            .createSqlType(SqlTypeName.INTEGER), HiveGroupingID.INSTANCE.getName());
    aggregateCalls.add(aggCall);
    return new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
        aggregate.getInput(), true, groupSet, origGroupSets, aggregateCalls);
  }

  /**
   * Returns the number of count DISTINCT calls in the aggregate.
   *
   * @param hiveAggregate the aggregate to inspect
   * @return the number of count DISTINCT
   */
  private int getNumCountDistinctCall(Aggregate hiveAggregate) {
    int cnt = 0;
    for (AggregateCall aggCall : hiveAggregate.getAggCallList()) {
      if (aggCall.isDistinct() && (aggCall.getAggregation().getName().equalsIgnoreCase("count"))) {
        cnt++;
      }
    }
    return cnt;
  }

  /**
   * Converts an aggregate relational expression that contains just one
   * distinct aggregate function (or perhaps several over the same arguments)
   * and no non-distinct aggregate functions.
   *
   * @param aggregate the original aggregate
   * @param argList   the argument list shared by all distinct calls
   * @return the original aggregate re-created on top of a "select distinct"
   */
  private RelNode convertMonopole(
      Aggregate aggregate,
      List<Integer> argList) {
    // For example,
    //    SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
    //    FROM emp
    //    GROUP BY deptno
    //
    // becomes
    //
    //    SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
    //    FROM (
    //      SELECT DISTINCT deptno, sal AS distinct_sal
    //      FROM EMP GROUP BY deptno)
    //    GROUP BY deptno

    // Project the columns of the GROUP BY plus the arguments
    // to the agg function.
    Map<Integer, Integer> sourceOf = new HashMap<>();
    final Aggregate distinct =
        createSelectDistinct(aggregate, argList, sourceOf);

    // Create an aggregate on top, with the new aggregate list.
    final List<AggregateCall> newAggCalls =
        Lists.newArrayList(aggregate.getAggCallList());
    rewriteAggCalls(newAggCalls, argList, sourceOf);
    // The group columns are projected first by createSelectDistinct, so they
    // occupy positions 0..cardinality-1 of its output.
    final int cardinality = aggregate.getGroupSet().cardinality();
    return aggregate.copy(aggregate.getTraitSet(), distinct,
        aggregate.indicator, ImmutableBitSet.range(cardinality), null,
        newAggCalls);
  }

  /**
   * Replaces, in place, every distinct call whose arguments equal
   * {@code argList} with a non-distinct call over the re-mapped arguments.
   *
   * @param newAggCalls aggregate calls to rewrite; modified in place
   * @param argList     argument list of the distinct calls to rewrite
   * @param sourceOf    mapping from original input ordinal to the ordinal in
   *                    the "select distinct" output
   */
  private static void rewriteAggCalls(
      List<AggregateCall> newAggCalls,
      List<Integer> argList,
      Map<Integer, Integer> sourceOf) {
    // Rewrite the agg calls. Each distinct agg becomes a non-distinct call
    // to the corresponding field from the right; for example,
    // "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
    for (int i = 0; i < newAggCalls.size(); i++) {
      final AggregateCall aggCall = newAggCalls.get(i);

      // Ignore agg calls which are not distinct or have the wrong set
      // arguments. If we're rewriting aggs whose args are {sal}, we will
      // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
      // COUNT(DISTINCT gender) or SUM(sal).
      if (!aggCall.isDistinct()) {
        continue;
      }
      if (!aggCall.getArgList().equals(argList)) {
        continue;
      }

      // Re-map arguments.
      final int argCount = aggCall.getArgList().size();
      final List<Integer> newArgs = new ArrayList<>(argCount);
      for (int j = 0; j < argCount; j++) {
        final Integer arg = aggCall.getArgList().get(j);
        newArgs.add(sourceOf.get(arg));
      }
      final AggregateCall newAggCall =
          new AggregateCall(
              aggCall.getAggregation(),
              false,
              newArgs,
              aggCall.getType(),
              aggCall.getName());
      newAggCalls.set(i, newAggCall);
    }
  }

  /**
   * Given an {@link org.apache.calcite.rel.core.Aggregate}
   * and the ordinals of the arguments to a
   * particular call to an aggregate function, creates a 'select distinct'
   * relational expression which projects the group columns and those
   * arguments but nothing else.
   *
   * <p>For example, given
   *
   * <blockquote>
   * <pre>select f0, count(distinct f1), count(distinct f2)
   * from t group by f0</pre>
   * </blockquote>
   *
   * and the arglist
   *
   * <blockquote>{2}</blockquote>
   *
   * returns
   *
   * <blockquote>
   * <pre>select distinct f0, f2 from t</pre>
   * </blockquote>
   *
   * <p>The <code>sourceOf</code> map is populated with the output position of
   * each projected input column; in this case sourceOf.get(0) = 0, and
   * sourceOf.get(2) = 1.</p>
   *
   * @param aggregate Aggregate relational expression
   * @param argList   Ordinals of columns to make distinct
   * @param sourceOf  Out parameter, is populated with a map from input
   *                  ordinal to the position of that column in the output
   * @return Aggregate relational expression which projects the required
   * columns
   */
  private Aggregate createSelectDistinct(
      Aggregate aggregate,
      List<Integer> argList,
      Map<Integer, Integer> sourceOf) {
    final List<Pair<RexNode, String>> projects = new ArrayList<>();
    final RelNode child = aggregate.getInput();
    final List<RelDataTypeField> childFields =
        child.getRowType().getFieldList();
    // Group columns first ...
    for (int i : aggregate.getGroupSet()) {
      sourceOf.put(i, projects.size());
      projects.add(RexInputRef.of2(i, childFields));
    }
    // ... then the distinct arguments that are not already projected.
    for (Integer arg : argList) {
      if (sourceOf.get(arg) != null) {
        continue;
      }
      sourceOf.put(arg, projects.size());
      projects.add(RexInputRef.of2(arg, childFields));
    }
    final RelNode project =
        projFactory.createProject(child, Pair.left(projects), Pair.right(projects));

    // Get the distinct values of the GROUP BY fields and the arguments
    // to the agg functions.
    return aggregate.copy(aggregate.getTraitSet(), project, false,
        ImmutableBitSet.range(projects.size()),
        null, ImmutableList.<AggregateCall>of());
  }
}
// End HiveExpandDistinctAggregatesRule.java