/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ReflectUtil;
import org.apache.calcite.util.ReflectiveVisitor;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import com.google.common.collect.ImmutableList;
/**
* This class infers the order in Aggregate columns and the order of conjuncts
* in a Join condition that might be more beneficial to avoid additional sort
* stages. The only visible change is that order of join conditions might change.
* Further, Aggregate operators might get annotated with order in which Aggregate
* columns should be generated when we transform the operator tree into AST or
* Hive operator tree.
*/
public class HiveRelColumnsAlignment implements ReflectiveVisitor {

  /** Reflective dispatcher that selects the most specific {@code align(X, List)} overload. */
  private final ReflectUtil.MethodDispatcher<RelNode> alignDispatcher;
  /** Builder whose RexBuilder is used to recompose join conditions. */
  private final RelBuilder relBuilder;

  /**
   * Creates a HiveRelColumnsAlignment.
   *
   * @param relBuilder builder used to obtain the RexBuilder when join conditions
   *                   are recomposed
   */
  public HiveRelColumnsAlignment(RelBuilder relBuilder) {
    this.relBuilder = relBuilder;
    this.alignDispatcher =
        ReflectUtil.createMethodDispatcher(
            RelNode.class,
            this,
            "align",
            RelNode.class,
            List.class);
  }

  /**
   * Execute the logic in this class. In particular, make a top-down traversal of the tree
   * and annotate and recreate appropriate operators.
   *
   * @param root root of the plan to align
   * @return the recreated plan; the traversal starts with no required collations
   */
  public RelNode align(RelNode root) {
    return dispatchAlign(root, ImmutableList.<RelFieldCollation>of());
  }

  /**
   * Dispatches to the {@code align} overload matching the runtime type of {@code node}.
   *
   * @param node       operator to process
   * @param collations collations requested by the parent, expressed on the output
   *                   columns of {@code node}
   * @return the recreated operator
   */
  protected final RelNode dispatchAlign(RelNode node, List<RelFieldCollation> collations) {
    return alignDispatcher.invoke(node, collations);
  }

  /**
   * Aligns an Aggregate: the group-by columns referenced by the incoming collations
   * are recorded (in collation order) on the new aggregate, and collations on the
   * group-by columns are propagated to the input.
   *
   * @param rel        aggregate to process; its copy is cast to {@link HiveAggregate}
   * @param collations collations requested by the parent on the aggregate output
   * @return a copy of the aggregate over the aligned input, annotated with the
   *         inferred order of aggregate columns
   */
  public RelNode align(Aggregate rel, List<RelFieldCollation> collations) {
    // 1) We extract the group by positions that are part of the collations and
    // sort them so they respect it
    LinkedHashSet<Integer> aggregateColumnsOrder = new LinkedHashSet<>();
    ImmutableList.Builder<RelFieldCollation> propagateCollations = ImmutableList.builder();
    if (!rel.indicator && !collations.isEmpty()) {
      for (RelFieldCollation c : collations) {
        if (c.getFieldIndex() < rel.getGroupCount()) {
          // Group column found; only the first collation on a given column counts
          if (aggregateColumnsOrder.add(c.getFieldIndex())) {
            // Translate the output position into the input column it groups on
            propagateCollations.add(c.copy(rel.getGroupSet().nth(c.getFieldIndex())));
          }
        }
      }
    }
    for (int i = 0; i < rel.getGroupCount(); i++) {
      if (!aggregateColumnsOrder.contains(i)) {
        // Not included in the input collations, but can be propagated as this Aggregate
        // will enforce it
        propagateCollations.add(new RelFieldCollation(rel.getGroupSet().nth(i)));
      }
    }
    // 2) We propagate
    final RelNode child = dispatchAlign(rel.getInput(), propagateCollations.build());
    // 3) We annotate the Aggregate operator with this info
    final HiveAggregate newAggregate = (HiveAggregate) rel.copy(rel.getTraitSet(),
        ImmutableList.of(child));
    newAggregate.setAggregateColumnsOrder(aggregateColumnsOrder);
    return newAggregate;
  }

  /**
   * Aligns a Join: equality conjuncts between a left and a right column are reordered
   * so that those matching the incoming collations come first, followed by the other
   * such equalities, followed by every remaining conjunct. The corresponding
   * collations are propagated to both inputs.
   *
   * <p>Every conjunct of the original condition is kept in the new condition; only
   * their order may change.
   *
   * @param rel        join to process
   * @param collations collations requested by the parent on the join output
   * @return a copy of the join over the aligned inputs with the reordered condition
   */
  public RelNode align(Join rel, List<RelFieldCollation> collations) {
    ImmutableList.Builder<RelFieldCollation> propagateCollationsLeft = ImmutableList.builder();
    ImmutableList.Builder<RelFieldCollation> propagateCollationsRight = ImmutableList.builder();
    final int nLeftColumns = rel.getLeft().getRowType().getFieldList().size();
    Map<Integer,RexNode> idxToConjuncts = new HashMap<>();
    Map<Integer,Integer> refToRef = new HashMap<>();
    // 1) We extract the conditions that can be useful
    List<RexNode> conjuncts = new ArrayList<>();
    List<RexNode> otherConjuncts = new ArrayList<>();
    // Every left-right equality found, in original order. The index-keyed maps
    // above keep a single conjunct per column, so this list is what guarantees
    // that none of them is lost when a column appears in several equalities.
    List<RexNode> equiConjuncts = new ArrayList<>();
    for (RexNode conj : RelOptUtil.conjunctions(rel.getCondition())) {
      if (conj.getKind() != SqlKind.EQUALS) {
        otherConjuncts.add(conj);
        continue;
      }
      // TODO: Currently we only support EQUAL operator on two references.
      // We might extend the logic to support other (order-preserving)
      // UDFs here.
      RexCall equals = (RexCall) conj;
      if (!(equals.getOperands().get(0) instanceof RexInputRef) ||
          !(equals.getOperands().get(1) instanceof RexInputRef)) {
        otherConjuncts.add(conj);
        continue;
      }
      RexInputRef ref0 = (RexInputRef) equals.getOperands().get(0);
      RexInputRef ref1 = (RexInputRef) equals.getOperands().get(1);
      if ((ref0.getIndex() < nLeftColumns && ref1.getIndex() >= nLeftColumns) ||
          (ref1.getIndex() < nLeftColumns && ref0.getIndex() >= nLeftColumns)) {
        // We made sure the references are for different join inputs
        equiConjuncts.add(equals);
        idxToConjuncts.put(ref0.getIndex(), equals);
        idxToConjuncts.put(ref1.getIndex(), equals);
        refToRef.put(ref0.getIndex(), ref1.getIndex());
        refToRef.put(ref1.getIndex(), ref0.getIndex());
      } else {
        otherConjuncts.add(conj);
      }
    }
    // 2) We extract the collation for this operator and the collations
    // that we will propagate to the inputs of the join
    for (RelFieldCollation c : collations) {
      RexNode equals = idxToConjuncts.get(c.getFieldIndex());
      if (equals != null) {
        conjuncts.add(equals);
        idxToConjuncts.remove(c.getFieldIndex());
        idxToConjuncts.remove(refToRef.get(c.getFieldIndex()));
        if (c.getFieldIndex() < nLeftColumns) {
          propagateCollationsLeft.add(c.copy(c.getFieldIndex()));
          propagateCollationsRight.add(c.copy(refToRef.get(c.getFieldIndex()) - nLeftColumns));
        } else {
          propagateCollationsLeft.add(c.copy(refToRef.get(c.getFieldIndex())));
          propagateCollationsRight.add(c.copy(c.getFieldIndex() - nLeftColumns));
        }
      }
    }
    final Set<RexNode> visited = new HashSet<>();
    for (Entry<Integer,RexNode> e : idxToConjuncts.entrySet()) {
      if (visited.add(e.getValue())) {
        // Not included in the input collations, but can be propagated as this Join
        // might enforce it
        conjuncts.add(e.getValue());
        if (e.getKey() < nLeftColumns) {
          propagateCollationsLeft.add(new RelFieldCollation(e.getKey()));
          propagateCollationsRight.add(new RelFieldCollation(refToRef.get(e.getKey()) - nLeftColumns));
        } else {
          propagateCollationsLeft.add(new RelFieldCollation(refToRef.get(e.getKey())));
          propagateCollationsRight.add(new RelFieldCollation(e.getKey() - nLeftColumns));
        }
      }
    }
    // Equalities whose map entries were overwritten or evicted above (a column used
    // in more than one equality) have not been emitted yet; append them so that the
    // new condition stays equivalent to the original one.
    final Set<RexNode> emitted = new HashSet<>(conjuncts);
    for (RexNode candidate : equiConjuncts) {
      if (emitted.add(candidate)) {
        conjuncts.add(candidate);
      }
    }
    conjuncts.addAll(otherConjuncts);
    // 3) We propagate
    final RelNode newLeftInput = dispatchAlign(rel.getLeft(), propagateCollationsLeft.build());
    final RelNode newRightInput = dispatchAlign(rel.getRight(), propagateCollationsRight.build());
    // 4) We change the Join operator to reflect this info
    return rel.copy(rel.getTraitSet(), RexUtil.composeConjunction(
        relBuilder.getRexBuilder(), conjuncts, false), newLeftInput, newRightInput,
        rel.getJoinType(), rel.isSemiJoinDone());
  }

  /**
   * Aligns a set operation by passing the incoming collations unchanged to every input.
   *
   * @param rel        set operation to process
   * @param collations collations requested by the parent
   * @return a copy of the set operation over the aligned inputs
   */
  public RelNode align(SetOp rel, List<RelFieldCollation> collations) {
    ImmutableList.Builder<RelNode> newInputs = new ImmutableList.Builder<>();
    for (RelNode input : rel.getInputs()) {
      newInputs.add(dispatchAlign(input, collations));
    }
    return rel.copy(rel.getTraitSet(), newInputs.build());
  }

  /**
   * Aligns a Project: collations on output columns that are plain input references
   * are remapped to the referenced input column and propagated. Nothing is
   * propagated when one of the projected expressions is itself a {@link RexOver}.
   *
   * @param rel        project to process
   * @param collations collations requested by the parent on the project output
   * @return a copy of the project over the aligned input
   */
  public RelNode align(Project rel, List<RelFieldCollation> collations) {
    // 1) We extract the collations indices
    boolean containsWindowing = false;
    for (RexNode childExp : rel.getChildExps()) {
      // Only expressions that are a RexOver at the top level are detected here
      if (childExp instanceof RexOver) {
        // TODO: support propagation for partitioning/ordering in windowing
        containsWindowing = true;
        break;
      }
    }
    ImmutableList.Builder<RelFieldCollation> propagateCollations = ImmutableList.builder();
    if (!containsWindowing) {
      for (RelFieldCollation c : collations) {
        RexNode rexNode = rel.getChildExps().get(c.getFieldIndex());
        if (rexNode instanceof RexInputRef) {
          int newIdx = ((RexInputRef) rexNode).getIndex();
          propagateCollations.add(c.copy(newIdx));
        }
      }
    }
    // 2) We propagate
    final RelNode child = dispatchAlign(rel.getInput(), propagateCollations.build());
    // 3) Return new Project
    return rel.copy(rel.getTraitSet(), ImmutableList.of(child));
  }

  /**
   * Aligns a Filter by passing the incoming collations unchanged to its input.
   *
   * @param rel        filter to process
   * @param collations collations requested by the parent
   * @return a copy of the filter over the aligned input
   */
  public RelNode align(Filter rel, List<RelFieldCollation> collations) {
    final RelNode child = dispatchAlign(rel.getInput(), collations);
    return rel.copy(rel.getTraitSet(), ImmutableList.of(child));
  }

  /**
   * Aligns a Sort: the incoming collations are ignored and the field collations of
   * the Sort itself are propagated to its input instead.
   *
   * @param rel        sort to process
   * @param collations collations requested by the parent (unused)
   * @return a copy of the sort over the aligned input
   */
  public RelNode align(Sort rel, List<RelFieldCollation> collations) {
    final RelNode child = dispatchAlign(rel.getInput(), rel.collation.getFieldCollations());
    return rel.copy(rel.getTraitSet(), ImmutableList.of(child));
  }

  /**
   * Catch-all rule when none of the others apply: the incoming collations are
   * discarded and every input is processed with no required collations.
   *
   * @param rel        operator to process
   * @param collations collations requested by the parent (unused)
   * @return a copy of the operator over the aligned inputs
   */
  public RelNode align(RelNode rel, List<RelFieldCollation> collations) {
    ImmutableList.Builder<RelNode> newInputs = new ImmutableList.Builder<>();
    for (RelNode input : rel.getInputs()) {
      newInputs.add(dispatchAlign(input, ImmutableList.<RelFieldCollation>of()));
    }
    return rel.copy(rel.getTraitSet(), newInputs.build());
  }
}