/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories.ProjectFactory;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil.JoinPredicateInfo;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

/**
 * Rule that merges a join with its multijoin/join children if
 * the equi-join predicates compare the same set of input columns.
 */
public class HiveJoinToMultiJoinRule extends RelOptRule {

  public static final HiveJoinToMultiJoinRule INSTANCE =
      new HiveJoinToMultiJoinRule(HiveJoin.class, HiveRelFactories.HIVE_PROJECT_FACTORY);

  private final ProjectFactory projectFactory;

  private static transient final Logger LOG =
      LoggerFactory.getLogger(HiveJoinToMultiJoinRule.class);

  //~ Constructors -----------------------------------------------------------

  /**
   * Creates a HiveJoinToMultiJoinRule.
   */
  public HiveJoinToMultiJoinRule(Class<? extends Join> clazz, ProjectFactory projectFactory) {
    super(operand(clazz,
        operand(RelNode.class, any()),
        operand(RelNode.class, any())));
    this.projectFactory = projectFactory;
  }

  //~ Methods ----------------------------------------------------------------

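  /**
   * First tries to merge the join with its left child; if that is not
   * possible, swaps the join inputs and retries the merge against the
   * original right child, re-creating on top of the result the Project
   * that the swap may have introduced.
   */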
  @Override
  public void onMatch(RelOptRuleCall call) {
    final HiveJoin join = call.rel(0);
    final RelNode left = call.rel(1);
    final RelNode right = call.rel(2);

    // 1. We try to merge this join with the left child
    RelNode multiJoin = mergeJoin(join, left, right);
    if (multiJoin != null) {
      call.transformTo(multiJoin);
      return;
    }

    // 2. If we cannot, we swap the inputs so we can try
    //    to merge it with its right child
    RelNode swapped = JoinCommuteRule.swap(join, true);
    assert swapped != null;

    // The result of the swapping operation is either
    // i) a Project or,
    // ii) if the project is trivial, a raw join
    final HiveJoin newJoin;
    Project topProject = null;
    if (swapped instanceof HiveJoin) {
      newJoin = (HiveJoin) swapped;
    } else {
      topProject = (Project) swapped;
      newJoin = (HiveJoin) swapped.getInput(0);
    }

    // 3. We try to merge the join with the right child
    multiJoin = mergeJoin(newJoin, right, left);
    if (multiJoin != null) {
      if (topProject != null) {
        multiJoin = projectFactory.createProject(multiJoin, topProject.getChildExps(),
            topProject.getRowType().getFieldNames());
      }
      call.transformTo(multiJoin);
      return;
    }
  }

  // This method tries to merge the join with its left child. The left
  // child should be a join or multijoin operator for this to happen.
  private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) {
    final RexBuilder rexBuilder = join.getCluster().getRexBuilder();

    // We check whether the join can be combined with any of its children
    final List<RelNode> newInputs = Lists.newArrayList();
    final List<RexNode> newJoinCondition = Lists.newArrayList();
    final List<Pair<Integer,Integer>> joinInputs = Lists.newArrayList();
    final List<JoinRelType> joinTypes = Lists.newArrayList();
    final List<RexNode> joinFilters = Lists.newArrayList();

    // Left child
    if (left instanceof HiveJoin || left instanceof HiveMultiJoin) {
      final RexNode leftCondition;
      final List<Pair<Integer,Integer>> leftJoinInputs;
      final List<JoinRelType> leftJoinTypes;
      final List<RexNode> leftJoinFilters;

      boolean combinable;
      if (left instanceof HiveJoin) {
        HiveJoin hj = (HiveJoin) left;
        leftCondition = hj.getCondition();
        leftJoinInputs = ImmutableList.of(Pair.of(0, 1));
        leftJoinTypes = ImmutableList.of(hj.getJoinType());
        leftJoinFilters = ImmutableList.of(hj.getJoinFilter());
        try {
          combinable = isCombinableJoin(join, hj);
        } catch (CalciteSemanticException e) {
          LOG.trace("Failed to merge join-join", e);
          combinable = false;
        }
      } else {
        HiveMultiJoin hmj = (HiveMultiJoin) left;
        leftCondition = hmj.getCondition();
        leftJoinInputs = hmj.getJoinInputs();
        leftJoinTypes = hmj.getJoinTypes();
        leftJoinFilters = hmj.getJoinFilters();
        try {
          combinable = isCombinableJoin(join, hmj);
        } catch (CalciteSemanticException e) {
          LOG.trace("Failed to merge join-multijoin", e);
          combinable = false;
        }
      }

      if (combinable) {
        newJoinCondition.add(leftCondition);
        for (int i = 0; i < leftJoinInputs.size(); i++) {
          joinInputs.add(leftJoinInputs.get(i));
          joinTypes.add(leftJoinTypes.get(i));
          joinFilters.add(leftJoinFilters.get(i));
        }
        newInputs.addAll(left.getInputs());
      } else {
        // The join operation in the child is not on the same keys
        return null;
      }
    } else {
      // The left child is not a join or multijoin operator
      return null;
    }
    final int numberLeftInputs = newInputs.size();

    // Right child
    newInputs.add(right);

    // If we cannot combine any of the children, we bail out
    newJoinCondition.add(join.getCondition());
    if (newJoinCondition.size() == 1) {
      return null;
    }

    final List<RelDataTypeField> systemFieldList = ImmutableList.of();
    List<List<RexNode>> joinKeyExprs = new ArrayList<List<RexNode>>();
    List<Integer> filterNulls = new ArrayList<Integer>();
    for (int i = 0; i < newInputs.size(); i++) {
      joinKeyExprs.add(new ArrayList<RexNode>());
    }
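    // Split the overall join condition over the new inputs: for each input,
    // the equi-join key expressions it contributes are collected into
    // joinKeyExprs, and the returned expression is used below as the join
    // filter of the new join specs.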
    RexNode filters;
    try {
      filters = HiveRelOptUtil.splitHiveJoinCondition(systemFieldList, newInputs,
          join.getCondition(), joinKeyExprs, filterNulls, null);
    } catch (CalciteSemanticException e) {
      LOG.trace("Failed to merge joins", e);
      return null;
    }
    ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder();
    for (int i = 0; i < newInputs.size(); i++) {
      List<RexNode> partialCondition = joinKeyExprs.get(i);
      if (!partialCondition.isEmpty()) {
        keysInInputsBuilder.set(i);
      }
    }

    // If we cannot merge, we bail out
    ImmutableBitSet keysInInputs = keysInInputsBuilder.build();
    ImmutableBitSet leftReferencedInputs =
        keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs));
    ImmutableBitSet rightReferencedInputs =
        keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs, newInputs.size()));
    if (join.getJoinType() != JoinRelType.INNER &&
        (leftReferencedInputs.cardinality() > 1 || rightReferencedInputs.cardinality() > 1)) {
      return null;
    }

    // Otherwise, we add to the join specs
    if (join.getJoinType() != JoinRelType.INNER) {
      int leftInput = keysInInputs.nextSetBit(0);
      int rightInput = keysInInputs.nextSetBit(numberLeftInputs);
      joinInputs.add(Pair.of(leftInput, rightInput));
      joinTypes.add(join.getJoinType());
      joinFilters.add(filters);
    } else {
      for (int i : leftReferencedInputs) {
        for (int j : rightReferencedInputs) {
          joinInputs.add(Pair.of(i, j));
          joinTypes.add(join.getJoinType());
          joinFilters.add(filters);
        }
      }
    }

    // We can now create a multijoin operator
    RexNode newCondition = RexUtil.flatten(rexBuilder,
        RexUtil.composeConjunction(rexBuilder, newJoinCondition, false));
    List<RelNode> newInputsArray = Lists.newArrayList(newInputs);
    JoinPredicateInfo joinPredInfo = null;
    try {
      joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(
          newInputsArray, systemFieldList, newCondition);
    } catch (CalciteSemanticException e) {
      throw new RuntimeException(e);
    }

    // If the number of joins < number of input tables - 1, this is not a star join.
    if (joinPredInfo.getEquiJoinPredicateElements().size() < newInputs.size() - 1) {
      return null;
    }

    // Validate that the multi-join is a valid star join before returning it.
    for (int i = 0; i < newInputs.size(); i++) {
      List<RexNode> joinKeys = null;
      for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) {
        List<RexNode> currJoinKeys = joinPredInfo.
            getEquiJoinPredicateElements().get(j).getJoinExprs(i);
        if (currJoinKeys.isEmpty()) {
          continue;
        }
        if (joinKeys == null) {
          joinKeys = currJoinKeys;
        } else {
          // If we join on different keys on different tables, we can no longer apply
          // multi-join conversion as this is no longer a valid star join.
          // Bail out if this is the case.
          if (!joinKeys.containsAll(currJoinKeys) || !currJoinKeys.containsAll(joinKeys)) {
            return null;
          }
        }
      }
    }

    return new HiveMultiJoin(
        join.getCluster(),
        newInputsArray,
        newCondition,
        join.getRowType(),
        joinInputs,
        joinTypes,
        joinFilters,
        joinPredInfo);
  }

  /*
   * Returns true if the join conditions execute over the same keys.
   */
  private static boolean isCombinableJoin(HiveJoin join, HiveJoin leftChildJoin)
      throws CalciteSemanticException {
    final JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.
        constructJoinPredicateInfo(join, join.getCondition());
    final JoinPredicateInfo leftChildJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo.
        constructJoinPredicateInfo(leftChildJoin, leftChildJoin.getCondition());
    return isCombinablePredicate(joinPredInfo, leftChildJoinPredInfo,
        leftChildJoin.getInputs().size());
  }

  /*
   * Returns true if the join conditions execute over the same keys.
   */
  private static boolean isCombinableJoin(HiveJoin join, HiveMultiJoin leftChildJoin)
      throws CalciteSemanticException {
    final JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.
        constructJoinPredicateInfo(join, join.getCondition());
    final JoinPredicateInfo leftChildJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo.
        constructJoinPredicateInfo(leftChildJoin, leftChildJoin.getCondition());
    return isCombinablePredicate(joinPredInfo, leftChildJoinPredInfo,
        leftChildJoin.getInputs().size());
  }

  /*
   * To be able to combine a parent join and its left input join child,
   * the left keys over which the parent join is executed need to be the same
   * as those of the child join.
   * Thus, we iterate over the different inputs of the child, checking whether
   * the keys of the parent are the same.
   */
  private static boolean isCombinablePredicate(JoinPredicateInfo joinPredInfo,
      JoinPredicateInfo leftChildJoinPredInfo, int noLeftChildInputs)
      throws CalciteSemanticException {
    Set<Integer> keys = joinPredInfo.getProjsJoinKeysInChildSchema(0);
    if (keys.isEmpty()) {
      return false;
    }
    for (int i = 0; i < noLeftChildInputs; i++) {
      if (keys.equals(leftChildJoinPredInfo.getProjsJoinKeysInJoinSchema(i))) {
        return true;
      }
    }
    return false;
  }
}