/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.optimizer.calcite; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.plan.RelOptUtil.InputFinder; import org.apache.calcite.plan.RelOptUtil.InputReferencedVisitor; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.RelFactories.ProjectFactory; import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCorrelVariable; import org.apache.calcite.rex.RexDynamicParam; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import 
org.apache.calcite.rex.RexLocalRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexOver; import org.apache.calcite.rex.RexPatternFieldRef; import org.apache.calcite.rex.RexRangeRef; import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.rex.RexVisitor; import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.Pair; import org.apache.calcite.util.Util; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.ql.exec.FunctionInfo; import org.apache.hadoop.hive.ql.exec.FunctionRegistry; import org.apache.hadoop.hive.ql.metadata.VirtualColumn; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ExprNodeConverter; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter; import org.apache.hadoop.hive.ql.parse.ASTNode; import org.apache.hadoop.hive.ql.parse.HiveParser; import org.apache.hadoop.hive.ql.parse.ParseUtils; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Function; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import 
com.google.common.collect.ImmutableMap.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

/**
 * Generic utility functions needed for Calcite based Hive CBO.
 */
public class HiveCalciteUtil {

  private static final Logger LOG = LoggerFactory.getLogger(HiveCalciteUtil.class);

  /**
   * Get list of virtual columns from the given list of projections.
   * <p>
   * A projection is considered "virtual" here simply when it is not a plain
   * {@link RexInputRef}, i.e. it is a computed expression rather than a direct
   * column reference.
   *
   * @param exps
   *          list of rex nodes representing projections
   * @return list of positions (indexes into {@code exps}) holding virtual
   *         columns; never null, possibly empty.
   */
  public static List<Integer> getVirtualCols(List<? extends RexNode> exps) {
    List<Integer> vCols = new ArrayList<Integer>();
    for (int i = 0; i < exps.size(); i++) {
      if (!(exps.get(i) instanceof RexInputRef)) {
        vCols.add(i);
      }
    }
    return vCols;
  }

  /**
   * Returns false when the AST contains tokens CBO cannot handle
   * (charset literals and table split samples); true otherwise.
   */
  public static boolean validateASTForUnsupportedTokens(ASTNode ast) {
    if (ParseUtils.containsTokenOfType(ast, HiveParser.TOK_CHARSETLITERAL,
        HiveParser.TOK_TABLESPLITSAMPLE)) {
      return false;
    } else {
      return true;
    }
  }

  /**
   * Builds one {@link RexInputRef} per field of {@code rel}'s row type, in
   * field order — i.e. an identity projection over {@code rel}'s output.
   * Note: the returned list is a lazy Guava transform view over the field list.
   */
  public static List<RexNode> getProjsFromBelowAsInputRef(final RelNode rel) {
    List<RexNode> projectList = Lists.transform(rel.getRowType().getFieldList(),
        new Function<RelDataTypeField, RexNode>() {
          @Override
          public RexNode apply(RelDataTypeField field) {
            return rel.getCluster().getRexBuilder().makeInputRef(field.getType(), field.getIndex());
          }
        });
    return projectList;
  }

  /**
   * Converts a bit set of projection positions into an ordered list of the
   * set bit indexes.
   */
  public static List<Integer> translateBitSetToProjIndx(ImmutableBitSet projBitSet) {
    List<Integer> projIndxLst = new ArrayList<Integer>();
    for (int i = 0; i < projBitSet.length(); i++) {
      if (projBitSet.get(i)) {
        projIndxLst.add(i);
      }
    }
    return projIndxLst;
  }

  /**
   * Push any equi join conditions that are not column references as Projections
   * on top of the children.
   *
   * @param factory
   *          Project factory to use.
   * @param inputRels
   *          inputs to a join; on return may be replaced with new Project rels.
   * @param leftJoinKeys
   *          expressions for LHS of join key
   * @param rightJoinKeys
   *          expressions for RHS of join key
   * @param systemColCount
   *          number of system columns, usually zero. These columns are
   *          projected at the leading edge of the output row.
   * @param leftKeys
   *          on return this contains the join key positions from the new
   *          project rel on the LHS.
   * @param rightKeys
   *          on return this contains the join key positions from the new
   *          project rel on the RHS.
   * @return the join condition after the equi expressions pushed down.
   */
  public static RexNode projectNonColumnEquiConditions(ProjectFactory factory, RelNode[] inputRels,
      List<RexNode> leftJoinKeys, List<RexNode> rightJoinKeys, int systemColCount,
      List<Integer> leftKeys, List<Integer> rightKeys) {
    RelNode leftRel = inputRels[0];
    RelNode rightRel = inputRels[1];
    RexBuilder rexBuilder = leftRel.getCluster().getRexBuilder();
    RexNode outJoinCond = null;
    int origLeftInputSize = leftRel.getRowType().getFieldCount();
    int origRightInputSize = rightRel.getRowType().getFieldCount();
    List<RexNode> newLeftFields = new ArrayList<RexNode>();
    List<String> newLeftFieldNames = new ArrayList<String>();
    List<RexNode> newRightFields = new ArrayList<RexNode>();
    List<String> newRightFieldNames = new ArrayList<String>();
    int leftKeyCount = leftJoinKeys.size();
    int i;

    // Start both new projection lists with identity refs over the original fields.
    for (i = 0; i < origLeftInputSize; i++) {
      final RelDataTypeField field = leftRel.getRowType().getFieldList().get(i);
      newLeftFields.add(rexBuilder.makeInputRef(field.getType(), i));
      newLeftFieldNames.add(field.getName());
    }
    for (i = 0; i < origRightInputSize; i++) {
      final RelDataTypeField field = rightRel.getRowType().getFieldList().get(i);
      newRightFields.add(rexBuilder.makeInputRef(field.getType(), i));
      newRightFieldNames.add(field.getName());
    }

    // Split the key pairs: plain column-to-column equalities stay as-is
    // (origColEqConds); any pair involving a computed expression is appended to
    // the new projection lists and will become a new synthesized key column.
    ImmutableBitSet.Builder origColEqCondsPosBuilder = ImmutableBitSet.builder();
    int newKeyCount = 0;
    List<Pair<Integer, Integer>> origColEqConds = new ArrayList<Pair<Integer, Integer>>();
    for (i = 0; i < leftKeyCount; i++) {
      RexNode leftKey = leftJoinKeys.get(i);
      RexNode rightKey = rightJoinKeys.get(i);
      if (leftKey instanceof RexInputRef && rightKey instanceof RexInputRef) {
        origColEqConds.add(Pair.of(((RexInputRef) leftKey).getIndex(),
            ((RexInputRef) rightKey).getIndex()));
        origColEqCondsPosBuilder.set(i);
      } else {
        // Names are null here; uniquify() below generates fresh names for them.
        newLeftFields.add(leftKey);
        newLeftFieldNames.add(null);
        newRightFields.add(rightKey);
        newRightFieldNames.add(null);
        newKeyCount++;
      }
    }
    ImmutableBitSet origColEqCondsPos = origColEqCondsPosBuilder.build();

    // Rebuild equality conditions for the original column-to-column keys,
    // offsetting the RHS by system columns + entire (projected) LHS width.
    for (i = 0; i < origColEqConds.size(); i++) {
      Pair<Integer, Integer> p = origColEqConds.get(i);
      int condPos = origColEqCondsPos.nth(i);
      RexNode leftKey = leftJoinKeys.get(condPos);
      RexNode rightKey = rightJoinKeys.get(condPos);
      leftKeys.add(p.left);
      rightKeys.add(p.right);
      RexNode cond = rexBuilder.makeCall(
          SqlStdOperatorTable.EQUALS,
          rexBuilder.makeInputRef(leftKey.getType(), systemColCount + p.left),
          rexBuilder.makeInputRef(rightKey.getType(), systemColCount + origLeftInputSize
              + newKeyCount + p.right));
      if (outJoinCond == null) {
        outJoinCond = cond;
      } else {
        outJoinCond = rexBuilder.makeCall(SqlStdOperatorTable.AND, outJoinCond, cond);
      }
    }

    // No synthesized keys: the inputs are untouched, return the condition so far.
    if (newKeyCount == 0) {
      return outJoinCond;
    }

    // Build equality conditions over the newly projected key columns, which sit
    // after the original fields of each side in the join row.
    int newLeftOffset = systemColCount + origLeftInputSize;
    int newRightOffset = systemColCount + origLeftInputSize + origRightInputSize + newKeyCount;
    for (i = 0; i < newKeyCount; i++) {
      leftKeys.add(origLeftInputSize + i);
      rightKeys.add(origRightInputSize + i);
      RexNode cond = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
          rexBuilder.makeInputRef(newLeftFields.get(origLeftInputSize + i).getType(),
              newLeftOffset + i),
          rexBuilder.makeInputRef(newRightFields.get(origRightInputSize + i).getType(),
              newRightOffset + i));
      if (outJoinCond == null) {
        outJoinCond = cond;
      } else {
        outJoinCond = rexBuilder.makeCall(SqlStdOperatorTable.AND, outJoinCond, cond);
      }
    }

    // Add a Project on each side to produce the new key columns in addition to
    // the original input fields.
    if (newKeyCount > 0) {
      leftRel = factory.createProject(leftRel, newLeftFields,
          SqlValidatorUtil.uniquify(newLeftFieldNames));
      rightRel = factory.createProject(rightRel, newRightFields,
          SqlValidatorUtil.uniquify(newRightFieldNames));
    }
    inputRels[0] = leftRel;
    inputRels[1] = rightRel;
    return outJoinCond;
  }

  /**
   * JoinPredicateInfo represents Join condition; JoinPredicate Info uses
   * JoinLeafPredicateInfo to represent individual conjunctive elements in the
   * predicate.<br>
   * JoinPredicateInfo = JoinLeafPredicateInfo1 and JoinLeafPredicateInfo2...<br>
   * <p>
   * JoinPredicateInfo:<br>
   * 1. preserves the order of conjunctive elements for
   * equi-join(equiJoinPredicateElements)<br>
   * 2. Stores set of projection indexes from left and right child which is part
   * of equi join keys; the indexes are both in child and Join node schema.<br>
   * 3. Keeps a map of projection indexes that are part of join keys to list of
   * conjunctive elements(JoinLeafPredicateInfo) that uses them.
   */
  public static class JoinPredicateInfo {
    private final ImmutableList<JoinLeafPredicateInfo> nonEquiJoinPredicateElements;
    private final ImmutableList<JoinLeafPredicateInfo> equiJoinPredicateElements;
    // One entry per join input; each set holds key positions in that child's schema.
    private final ImmutableList<Set<Integer>> projsJoinKeysInChildSchema;
    // One entry per join input; each set holds key positions in the join's output schema.
    private final ImmutableList<Set<Integer>> projsJoinKeysInJoinSchema;
    private final ImmutableMap<Integer, ImmutableList<JoinLeafPredicateInfo>> mapOfProjIndxInJoinSchemaToLeafPInfo;

    public JoinPredicateInfo(List<JoinLeafPredicateInfo> nonEquiJoinPredicateElements,
        List<JoinLeafPredicateInfo> equiJoinPredicateElements,
        List<Set<Integer>> projsJoinKeysInChildSchema,
        List<Set<Integer>> projsJoinKeysInJoinSchema,
        Map<Integer, ImmutableList<JoinLeafPredicateInfo>> mapOfProjIndxInJoinSchemaToLeafPInfo) {
      this.nonEquiJoinPredicateElements = ImmutableList.copyOf(nonEquiJoinPredicateElements);
      this.equiJoinPredicateElements = ImmutableList.copyOf(equiJoinPredicateElements);
      this.projsJoinKeysInChildSchema = ImmutableList
          .copyOf(projsJoinKeysInChildSchema);
      this.projsJoinKeysInJoinSchema = ImmutableList
          .copyOf(projsJoinKeysInJoinSchema);
      this.mapOfProjIndxInJoinSchemaToLeafPInfo = ImmutableMap
          .copyOf(mapOfProjIndxInJoinSchemaToLeafPInfo);
    }

    public List<JoinLeafPredicateInfo> getNonEquiJoinPredicateElements() {
      return this.nonEquiJoinPredicateElements;
    }

    public List<JoinLeafPredicateInfo> getEquiJoinPredicateElements() {
      return this.equiJoinPredicateElements;
    }

    // The "Left"/"Right" accessors assume a binary join (assert size == 2);
    // for HiveMultiJoin use the indexed variants instead.
    public Set<Integer> getProjsFromLeftPartOfJoinKeysInChildSchema() {
      assert projsJoinKeysInChildSchema.size() == 2;
      return this.projsJoinKeysInChildSchema.get(0);
    }

    public Set<Integer> getProjsFromRightPartOfJoinKeysInChildSchema() {
      assert projsJoinKeysInChildSchema.size() == 2;
      return this.projsJoinKeysInChildSchema.get(1);
    }

    public Set<Integer> getProjsJoinKeysInChildSchema(int i) {
      return this.projsJoinKeysInChildSchema.get(i);
    }

    /**
     * NOTE: Join Schema = left Schema + (right Schema offset by
     * left.fieldcount). Hence its ok to return projections from left in child
     * schema.
     */
    public Set<Integer> getProjsFromLeftPartOfJoinKeysInJoinSchema() {
      assert projsJoinKeysInJoinSchema.size() == 2;
      return this.projsJoinKeysInJoinSchema.get(0);
    }

    public Set<Integer> getProjsFromRightPartOfJoinKeysInJoinSchema() {
      assert projsJoinKeysInJoinSchema.size() == 2;
      return this.projsJoinKeysInJoinSchema.get(1);
    }

    public Set<Integer> getProjsJoinKeysInJoinSchema(int i) {
      return this.projsJoinKeysInJoinSchema.get(i);
    }

    public Map<Integer, ImmutableList<JoinLeafPredicateInfo>> getMapOfProjIndxToLeafPInfo() {
      return this.mapOfProjIndxInJoinSchemaToLeafPInfo;
    }

    public static JoinPredicateInfo constructJoinPredicateInfo(Join j)
        throws CalciteSemanticException {
      return constructJoinPredicateInfo(j, j.getCondition());
    }

    public static JoinPredicateInfo constructJoinPredicateInfo(HiveMultiJoin mj)
        throws CalciteSemanticException {
      return constructJoinPredicateInfo(mj, mj.getCondition());
    }

    public static JoinPredicateInfo constructJoinPredicateInfo(Join j, RexNode predicate)
        throws CalciteSemanticException {
      return constructJoinPredicateInfo(j.getInputs(), j.getSystemFieldList(), predicate);
    }

    public static JoinPredicateInfo constructJoinPredicateInfo(HiveMultiJoin mj, RexNode predicate)
        throws CalciteSemanticException {
      // HiveMultiJoin carries no system fields.
      final List<RelDataTypeField> systemFieldList = ImmutableList.of();
      return constructJoinPredicateInfo(mj.getInputs(), systemFieldList, predicate);
    }

    /**
     * Decomposes {@code predicate} into its conjunctive elements, classifies
     * each as equi vs non-equi, and collects join-key positions per input in
     * both child and join schemas.
     */
    public static JoinPredicateInfo constructJoinPredicateInfo(List<RelNode> inputs,
        List<RelDataTypeField> systemFieldList, RexNode predicate)
        throws CalciteSemanticException {
      JoinPredicateInfo jpi = null;
      JoinLeafPredicateInfo jlpi = null;
      List<JoinLeafPredicateInfo> equiLPIList = new ArrayList<JoinLeafPredicateInfo>();
      List<JoinLeafPredicateInfo> nonEquiLPIList = new ArrayList<JoinLeafPredicateInfo>();
      List<Set<Integer>> projsJoinKeys = new ArrayList<Set<Integer>>();
      for (int i=0; i<inputs.size(); i++) {
        Set<Integer> projsJoinKeysInput = Sets.newHashSet();
        projsJoinKeys.add(projsJoinKeysInput);
      }
      List<Set<Integer>> projsJoinKeysInJoinSchema = new ArrayList<Set<Integer>>();
      for (int i=0; i<inputs.size(); i++) {
        Set<Integer> projsJoinKeysInJoinSchemaInput = Sets.newHashSet();
        projsJoinKeysInJoinSchema.add(projsJoinKeysInJoinSchemaInput);
      }
      Map<Integer, List<JoinLeafPredicateInfo>> tmpMapOfProjIndxInJoinSchemaToLeafPInfo =
          new HashMap<Integer, List<JoinLeafPredicateInfo>>();
      Map<Integer, ImmutableList<JoinLeafPredicateInfo>> mapOfProjIndxInJoinSchemaToLeafPInfo =
          new HashMap<Integer, ImmutableList<JoinLeafPredicateInfo>>();
      List<JoinLeafPredicateInfo> tmpJLPILst = null;
      List<RexNode> conjuctiveElements;

      // 1. Decompose Join condition to a number of leaf predicates
      // (conjunctive elements)
      conjuctiveElements = RelOptUtil.conjunctions(predicate);

      // 2. Walk through leaf predicates building up JoinLeafPredicateInfo
      for (RexNode ce : conjuctiveElements) {
        // 2.1 Construct JoinLeafPredicateInfo
        jlpi = JoinLeafPredicateInfo.constructJoinLeafPredicateInfo(inputs, systemFieldList, ce);

        // 2.2 Classify leaf predicate as Equi vs Non Equi
        if (jlpi.comparisonType.equals(SqlKind.EQUALS)) {
          equiLPIList.add(jlpi);
          // 2.2.1 Maintain join keys (in child & Join Schema)
          // 2.2.2 Update Join Key to JoinLeafPredicateInfo map with keys
          for (int i=0; i<inputs.size(); i++) {
            projsJoinKeys.get(i).addAll(jlpi.getProjsJoinKeysInChildSchema(i));
            projsJoinKeysInJoinSchema.get(i).addAll(jlpi.getProjsJoinKeysInJoinSchema(i));
            for (Integer projIndx : jlpi.getProjsJoinKeysInJoinSchema(i)) {
              tmpJLPILst = tmpMapOfProjIndxInJoinSchemaToLeafPInfo.get(projIndx);
              if (tmpJLPILst == null) {
                tmpJLPILst = new ArrayList<JoinLeafPredicateInfo>();
              }
              tmpJLPILst.add(jlpi);
              tmpMapOfProjIndxInJoinSchemaToLeafPInfo.put(projIndx, tmpJLPILst);
            }
          }
        } else {
          nonEquiLPIList.add(jlpi);
        }
      }

      // 3. Update Join Key to List<JoinLeafPredicateInfo> map to use
      // ImmutableList
      for (Entry<Integer, List<JoinLeafPredicateInfo>> e : tmpMapOfProjIndxInJoinSchemaToLeafPInfo
          .entrySet()) {
        mapOfProjIndxInJoinSchemaToLeafPInfo.put(e.getKey(), ImmutableList.copyOf(e.getValue()));
      }

      // 4. Construct JoinPredicateInfo
      jpi = new JoinPredicateInfo(nonEquiLPIList, equiLPIList, projsJoinKeys,
          projsJoinKeysInJoinSchema, mapOfProjIndxInJoinSchemaToLeafPInfo);
      return jpi;
    }
  }

  /**
   * JoinLeafPredicateInfo represents leaf predicate in Join condition
   * (conjunctive element).<br>
   * <p>
   * JoinLeafPredicateInfo:<br>
   * 1. Stores list of expressions from left and right child which is part of
   * equi join keys.<br>
   * 2. 
Stores set of projection indexes from left and right child which is part
   * of equi join keys; the indexes are both in child and Join node schema.<br>
   */
  public static class JoinLeafPredicateInfo {
    // SqlKind of the leaf comparison; SqlKind.OTHER when the predicate could
    // not be fully split into per-input equi-join expressions.
    private final SqlKind comparisonType;
    // One list of key expressions per join input.
    private final ImmutableList<ImmutableList<RexNode>> joinKeyExprs;
    private final ImmutableList<ImmutableSet<Integer>> projsJoinKeysInChildSchema;
    private final ImmutableList<ImmutableSet<Integer>> projsJoinKeysInJoinSchema;

    public JoinLeafPredicateInfo(
        SqlKind comparisonType, List<List<RexNode>> joinKeyExprs,
        List<Set<Integer>> projsJoinKeysInChildSchema,
        List<Set<Integer>> projsJoinKeysInJoinSchema) {
      this.comparisonType = comparisonType;
      // Deep-copy every nested collection into immutable form.
      ImmutableList.Builder<ImmutableList<RexNode>> joinKeyExprsBuilder =
          ImmutableList.builder();
      for (int i=0; i<joinKeyExprs.size(); i++) {
        joinKeyExprsBuilder.add(ImmutableList.copyOf(joinKeyExprs.get(i)));
      }
      this.joinKeyExprs = joinKeyExprsBuilder.build();
      ImmutableList.Builder<ImmutableSet<Integer>> projsJoinKeysInChildSchemaBuilder =
          ImmutableList.builder();
      for (int i=0; i<projsJoinKeysInChildSchema.size(); i++) {
        projsJoinKeysInChildSchemaBuilder.add(
            ImmutableSet.copyOf(projsJoinKeysInChildSchema.get(i)));
      }
      this.projsJoinKeysInChildSchema = projsJoinKeysInChildSchemaBuilder.build();
      ImmutableList.Builder<ImmutableSet<Integer>> projsJoinKeysInJoinSchemaBuilder =
          ImmutableList.builder();
      for (int i=0; i<projsJoinKeysInJoinSchema.size(); i++) {
        projsJoinKeysInJoinSchemaBuilder.add(
            ImmutableSet.copyOf(projsJoinKeysInJoinSchema.get(i)));
      }
      this.projsJoinKeysInJoinSchema = projsJoinKeysInJoinSchemaBuilder.build();
    }

    public List<RexNode> getJoinExprs(int input) {
      return this.joinKeyExprs.get(input);
    }

    // The "Left"/"Right" accessors assume a binary join (assert size == 2);
    // for multi-way joins use the indexed variants.
    public Set<Integer> getProjsFromLeftPartOfJoinKeysInChildSchema() {
      assert projsJoinKeysInChildSchema.size() == 2;
      return this.projsJoinKeysInChildSchema.get(0);
    }

    public Set<Integer> getProjsFromRightPartOfJoinKeysInChildSchema() {
      assert projsJoinKeysInChildSchema.size() == 2;
      return this.projsJoinKeysInChildSchema.get(1);
    }

    public Set<Integer> getProjsJoinKeysInChildSchema(int input) {
      return this.projsJoinKeysInChildSchema.get(input);
    }

    public Set<Integer> getProjsFromLeftPartOfJoinKeysInJoinSchema() {
      assert projsJoinKeysInJoinSchema.size() == 2;
      return this.projsJoinKeysInJoinSchema.get(0);
    }

    public Set<Integer> getProjsFromRightPartOfJoinKeysInJoinSchema() {
      assert projsJoinKeysInJoinSchema.size() == 2;
      return this.projsJoinKeysInJoinSchema.get(1);
    }

    public Set<Integer> getProjsJoinKeysInJoinSchema(int input) {
      return this.projsJoinKeysInJoinSchema.get(input);
    }

    // We create the join predicate info object. The object contains the join condition,
    // split accordingly. If the join condition is not part of the equi-join predicate,
    // the returned object will be typed as SQLKind.OTHER.
    private static JoinLeafPredicateInfo constructJoinLeafPredicateInfo(List<RelNode> inputs,
        List<RelDataTypeField> systemFieldList, RexNode pe) throws CalciteSemanticException {
      JoinLeafPredicateInfo jlpi = null;
      List<Integer> filterNulls = new ArrayList<Integer>();
      List<List<RexNode>> joinExprs = new ArrayList<List<RexNode>>();
      for (int i=0; i<inputs.size(); i++) {
        joinExprs.add(new ArrayList<RexNode>());
      }

      // 1. Split leaf join predicate to expressions from left, right
      RexNode otherConditions = HiveRelOptUtil.splitHiveJoinCondition(systemFieldList, inputs,
          pe, joinExprs, filterNulls, null);

      if (otherConditions.isAlwaysTrue()) {
        // The predicate split cleanly per input: treat it as an equi-style leaf.
        // 2. Collect child projection indexes used
        List<Set<Integer>> projsJoinKeysInChildSchema =
            new ArrayList<Set<Integer>>();
        for (int i=0; i<inputs.size(); i++) {
          ImmutableSet.Builder<Integer> projsFromInputJoinKeysInChildSchema = ImmutableSet.builder();
          InputReferencedVisitor irvLeft = new InputReferencedVisitor();
          irvLeft.apply(joinExprs.get(i));
          projsFromInputJoinKeysInChildSchema.addAll(irvLeft.inputPosReferenced);
          projsJoinKeysInChildSchema.add(projsFromInputJoinKeysInChildSchema.build());
        }

        // 3. Translate projection indexes to join schema, by adding offset.
        List<Set<Integer>> projsJoinKeysInJoinSchema =
            new ArrayList<Set<Integer>>();
        // The offset of the first input does not need to change.
        projsJoinKeysInJoinSchema.add(projsJoinKeysInChildSchema.get(0));
        for (int i=1; i<inputs.size(); i++) {
          // NOTE(review): only the immediately preceding input's field count is
          // used as the offset; for joins with 3+ inputs a cumulative offset
          // (sum of all previous inputs) would seem to be required — verify
          // against HiveMultiJoin callers.
          int offSet = inputs.get(i-1).getRowType().getFieldCount();
          ImmutableSet.Builder<Integer> projsFromInputJoinKeysInJoinSchema = ImmutableSet.builder();
          for (Integer indx : projsJoinKeysInChildSchema.get(i)) {
            projsFromInputJoinKeysInJoinSchema.add(indx + offSet);
          }
          projsJoinKeysInJoinSchema.add(projsFromInputJoinKeysInJoinSchema.build());
        }

        // 4. Construct JoinLeafPredicateInfo
        jlpi = new JoinLeafPredicateInfo(pe.getKind(), joinExprs,
            projsJoinKeysInChildSchema, projsJoinKeysInJoinSchema);
      } else {
        // Not a clean equi split: attach the whole predicate to whichever single
        // input covers all its column references, and type the leaf as OTHER.
        // 2. Construct JoinLeafPredicateInfo
        ImmutableBitSet refCols = InputFinder.bits(pe);
        int count = 0;
        for (int i=0; i<inputs.size(); i++) {
          final int length = inputs.get(i).getRowType().getFieldCount();
          ImmutableBitSet inputRange = ImmutableBitSet.range(count, count + length);
          if (inputRange.contains(refCols)) {
            joinExprs.get(i).add(pe);
          }
          count += length;
        }
        jlpi = new JoinLeafPredicateInfo(SqlKind.OTHER, joinExprs,
            new ArrayList<Set<Integer>>(), new ArrayList<Set<Integer>>());
      }

      return jlpi;
    }
  }

  /** Returns true when {@code rel} is a Sort that applies a limit but no ordering. */
  public static boolean pureLimitRelNode(RelNode rel) {
    return limitRelNode(rel) && !orderRelNode(rel);
  }

  /** Returns true when {@code rel} is a Sort that applies an ordering but no limit. */
  public static boolean pureOrderRelNode(RelNode rel) {
    return !limitRelNode(rel) && orderRelNode(rel);
  }

  /** Returns true when {@code rel} is a Sort with a fetch (LIMIT) set. */
  public static boolean limitRelNode(RelNode rel) {
    if ((rel instanceof Sort) && ((Sort) rel).fetch != null) {
      return true;
    }
    return false;
  }

  /** Returns true when {@code rel} is a Sort with at least one collation field (ORDER BY). */
  public static boolean orderRelNode(RelNode rel) {
    if ((rel instanceof Sort) && !((Sort) rel).getCollation().getFieldCollations().isEmpty()) {
      return true;
    }
    return false;
  }

  /**
   * Get top level select starting from root. Assumption here is root can only
   * be Sort & Project. 
Also the top project should be at most 2 levels below
   * Sort; i.e Sort(Limit)-Sort(OB)-Select
   *
   * @param rootRel root of the plan to walk down from
   * @return pair of (parent of the top-most HiveProject, the HiveProject
   *         itself); the right element is null when no HiveProject is found
   *         before running out of inputs.
   */
  public static Pair<RelNode, RelNode> getTopLevelSelect(final RelNode rootRel) {
    RelNode tmpRel = rootRel;
    RelNode parentOforiginalProjRel = rootRel;
    HiveProject originalProjRel = null;

    // Walk input(0) chain until the first HiveProject, remembering its parent.
    while (tmpRel != null) {
      if (tmpRel instanceof HiveProject) {
        originalProjRel = (HiveProject) tmpRel;
        break;
      }
      parentOforiginalProjRel = tmpRel;
      tmpRel = tmpRel.getInput(0);
    }

    return (new Pair<RelNode, RelNode>(parentOforiginalProjRel, originalProjRel));
  }

  /** Returns true when the call's kind belongs to SqlKind.COMPARISON. */
  public static boolean isComparisonOp(RexCall call) {
    return call.getKind().belongsTo(SqlKind.COMPARISON);
  }

  // Renders a RexNode via toString(); used to compare predicates by their
  // string representation.
  public static final Function<RexNode, String> REX_STR_FN = new Function<RexNode, String>() {
    public String apply(RexNode r) {
      return r.toString();
    }
  };

  /** Convenience overload of the 3-arg variant with an empty exclusion set. */
  public static ImmutableList<RexNode> getPredsNotPushedAlready(RelNode inp,
      List<RexNode> predsToPushDown) {
    return getPredsNotPushedAlready(Sets.<String>newHashSet(), inp, predsToPushDown);
  }

  /**
   * Given a list of predicates to push down, this method returns the set of predicates
   * that still need to be pushed. Predicates need to be pushed because 1) their String
   * representation is not included in input set of predicates to exclude, or 2) they are
   * already in the subtree rooted at the input node.
   * This method updates the set of predicates to exclude with the String representation
   * of the predicates in the output and in the subtree.
   *
   * @param predicatesToExclude String representation of predicates that should be excluded;
   *                            mutated in place by this method
   * @param inp root of the subtree
   * @param predsToPushDown candidate predicates to push down through the subtree
   * @return list of predicates to push down
   */
  public static ImmutableList<RexNode> getPredsNotPushedAlready(Set<String> predicatesToExclude,
      RelNode inp, List<RexNode> predsToPushDown) {
    // Bail out if there is nothing to push
    if (predsToPushDown.isEmpty()) {
      return ImmutableList.of();
    }
    // Build map to not convert multiple times, further remove already included predicates
    Map<String,RexNode> stringToRexNode = Maps.newLinkedHashMap();
    for (RexNode r : predsToPushDown) {
      String rexNodeString = r.toString();
      // Set.add returns false for duplicates, filtering already-excluded preds.
      if (predicatesToExclude.add(rexNodeString)) {
        stringToRexNode.put(rexNodeString, r);
      }
    }
    if (stringToRexNode.isEmpty()) {
      return ImmutableList.of();
    }
    // Finally exclude preds that are already in the subtree as given by the metadata provider
    // Note: this is the last step, trying to avoid the expensive call to the metadata provider
    // if possible
    Set<String> predicatesInSubtree = Sets.newHashSet();
    for (RexNode pred : RelMetadataQuery.instance().getPulledUpPredicates(inp).pulledUpPredicates) {
      predicatesInSubtree.add(pred.toString());
      predicatesInSubtree.addAll(Lists.transform(RelOptUtil.conjunctions(pred), REX_STR_FN));
    }
    final ImmutableList.Builder<RexNode> newConjuncts = ImmutableList.builder();
    for (Entry<String,RexNode> e : stringToRexNode.entrySet()) {
      if (predicatesInSubtree.add(e.getKey())) {
        newConjuncts.add(e.getValue());
      }
    }
    predicatesToExclude.addAll(predicatesInSubtree);
    return newConjuncts.build();
  }

  /**
   * If {@code rex} is a comparison call, coerces all its operands to their
   * least-restrictive common type; otherwise returns {@code rex} unchanged.
   */
  public static RexNode getTypeSafePred(RelOptCluster cluster, RexNode rex, RelDataType rType) {
    RexNode typeSafeRex = rex;
    if ((typeSafeRex instanceof RexCall) && HiveCalciteUtil.isComparisonOp((RexCall) typeSafeRex)) {
      RexBuilder rb = cluster.getRexBuilder();
      List<RexNode> fixedPredElems = new ArrayList<RexNode>();
      RelDataType commonType = cluster.getTypeFactory().leastRestrictive(
          RexUtil.types(((RexCall) rex).getOperands()));
      for (RexNode rn : ((RexCall) rex).getOperands()) {
        fixedPredElems.add(rb.ensureType(commonType, rn, true));
      }
      typeSafeRex = rb.makeCall(((RexCall) typeSafeRex).getOperator(), fixedPredElems);
    }
    return typeSafeRex;
  }

  /**
   * Returns true when no call anywhere in {@code expr} uses a non-deterministic
   * operator. Uses Util.FoundOne as a short-circuit signal from the visitor.
   */
  public static boolean isDeterministic(RexNode expr) {
    boolean deterministic = true;

    RexVisitor<Void> visitor = new RexVisitorImpl<Void>(true) {
      @Override
      public Void visitCall(org.apache.calcite.rex.RexCall call) {
        if (!call.getOperator().isDeterministic()) {
          throw new Util.FoundOne(call);
        }
        return super.visitCall(call);
      }
    };

    try {
      expr.accept(visitor);
    } catch (Util.FoundOne e) {
      deterministic = false;
    }

    return deterministic;
  }

  // Visitor that signals (via Util.FoundOne) on non-deterministic calls and on
  // any rex node kind other than plain calls/literals/input refs; subclasses
  // decide how input refs are handled.
  private static class DeterMinisticFuncVisitorImpl extends RexVisitorImpl<Void> {
    protected DeterMinisticFuncVisitorImpl() {
      // deep = true: visit the whole expression tree.
      super(true);
    }

    @Override
    public Void visitCall(org.apache.calcite.rex.RexCall call) {
      if (!call.getOperator().isDeterministic()) {
        throw new Util.FoundOne(call);
      }
      return super.visitCall(call);
    }

    @Override
    public Void visitCorrelVariable(RexCorrelVariable correlVariable) {
      throw new Util.FoundOne(correlVariable);
    }

    @Override
    public Void visitLocalRef(RexLocalRef localRef) {
      throw new Util.FoundOne(localRef);
    }

    @Override
    public Void visitOver(RexOver over) {
      throw new Util.FoundOne(over);
    }

    @Override
    public Void visitDynamicParam(RexDynamicParam dynamicParam) {
      throw new Util.FoundOne(dynamicParam);
    }

    @Override
    public Void visitRangeRef(RexRangeRef rangeRef) {
      throw new Util.FoundOne(rangeRef);
    }

    @Override
    public Void visitFieldAccess(RexFieldAccess fieldAccess) {
      throw new Util.FoundOne(fieldAccess);
    }
  }

  /**
   * Returns true when {@code expr} is built only from deterministic functions
   * over literals (any input reference disqualifies it).
   */
  public static boolean isDeterministicFuncOnLiterals(RexNode expr) {
    boolean deterministicFuncOnLiterals = true;

    RexVisitor<Void> visitor = new DeterMinisticFuncVisitorImpl() {
      @Override
      public Void visitInputRef(RexInputRef inputRef) {
        throw new Util.FoundOne(inputRef);
      }
    };

    try {
      expr.accept(visitor);
    } catch (Util.FoundOne e) {
      deterministicFuncOnLiterals = false;
    }

    return deterministicFuncOnLiterals;
  }

  /**
   * Filters {@code exprs} down to those accepted by
   * {@link #isDeterministicFuncWithSingleInputRef}.
   * NOTE(review): this is an instance method in an otherwise all-static
   * utility class — likely intended to be static; confirm with callers.
   */
  public List<RexNode> getDeterministicFuncWithSingleInputRef(List<RexNode> exprs,
      final Set<Integer> validInputRefs) {
    List<RexNode> determExprsWithSingleRef = new ArrayList<RexNode>();
    for (RexNode e : exprs) {
      if (isDeterministicFuncWithSingleInputRef(e, validInputRefs)) {
        determExprsWithSingleRef.add(e);
      }
    }
    return determExprsWithSingleRef;
  }

  /**
   * Returns true when {@code expr} is deterministic and references exactly one
   * distinct input position, which must be contained in {@code validInputRefs}.
   */
  public static boolean isDeterministicFuncWithSingleInputRef(RexNode expr,
      final Set<Integer> validInputRefs) {
    boolean deterministicFuncWithSingleInputRef = true;

    RexVisitor<Void> visitor = new DeterMinisticFuncVisitorImpl() {
      // Distinct valid input positions seen so far; more than one fails.
      Set<Integer> inputRefs = new HashSet<Integer>();

      @Override
      public Void visitInputRef(RexInputRef inputRef) {
        if (validInputRefs.contains(inputRef.getIndex())) {
          inputRefs.add(inputRef.getIndex());
          if (inputRefs.size() > 1) {
            throw new Util.FoundOne(inputRef);
          }
        } else {
          throw new Util.FoundOne(inputRef);
        }
        return null;
      }
    };

    try {
      expr.accept(visitor);
    } catch (Util.FoundOne e) {
      deterministicFuncWithSingleInputRef = false;
    }

    return deterministicFuncWithSingleInputRef;
  }

  /** Maps consecutive positions (starting at {@code startIndx}) to the given column infos. */
  public static <T> ImmutableMap<Integer, T> getColInfoMap(List<T> hiveCols, int startIndx) {
    Builder<Integer, T> bldr = ImmutableMap.<Integer, T> builder();

    int indx = startIndx;
    for (T ci : hiveCols) {
      bldr.put(indx, ci);
      indx++;
    }

    return bldr.build();
  }

  /** Returns a copy of {@code hiveVCols} with every position shifted by {@code shift}. */
  public static ImmutableSet<Integer> shiftVColsSet(Set<Integer> hiveVCols, int shift) {
    ImmutableSet.Builder<Integer> bldr = ImmutableSet.<Integer> builder();

    for (Integer pos : hiveVCols) {
      bldr.add(shift + pos);
    }

    return bldr.build();
  }

  /** Maps consecutive positions (starting at {@code startIndx}) to virtual columns. */
  public static ImmutableMap<Integer, VirtualColumn> getVColsMap(List<VirtualColumn> hiveVCols,
      int startIndx) {
    Builder<Integer, VirtualColumn> bldr = ImmutableMap.<Integer, VirtualColumn> builder();

    int indx = startIndx;
    for (VirtualColumn vc : hiveVCols) {
      bldr.put(indx, vc);
      indx++;
    }

    return bldr.build();
  }

  /** Maps each table field name to its 0-based position. */
  public static ImmutableMap<String, Integer> getColNameIndxMap(List<FieldSchema> tableFields) {
    Builder<String, Integer> bldr = ImmutableMap.<String, Integer> builder();

    int indx = 0;
    for (FieldSchema fs : tableFields) {
      bldr.put(fs.getName(), indx);
      indx++;
    }

    return bldr.build();
  }

  /** Maps each row-type field name to its 0-based position. */
  public static ImmutableMap<String, Integer> getRowColNameIndxMap(List<RelDataTypeField> rowFields) {
    Builder<String, Integer> bldr = ImmutableMap.<String, Integer> builder();

    int indx = 0;
    for (RelDataTypeField rdt : rowFields) {
      bldr.put(rdt.getName(), indx);
      indx++;
    }

    return bldr.build();
  }

  /** Builds a RexInputRef into {@code inputRel}'s row type for each requested position. */
  public static ImmutableList<RexNode> getInputRef(List<Integer> inputRefs, RelNode inputRel) {
    ImmutableList.Builder<RexNode> bldr = ImmutableList.<RexNode> builder();
    for (int i : inputRefs) {
      bldr.add(new RexInputRef(i, inputRel.getRowType().getFieldList().get(i).getType()));
    }
    return bldr.build();
  }

  /** Converts the input-ref at {@code inputRefIndx} into an ExprNodeDesc via {@code exprConv}. */
  public static ExprNodeDesc getExprNode(Integer inputRefIndx, RelNode inputRel,
      ExprNodeConverter exprConv) {
    ExprNodeDesc exprNode = null;
    RexNode rexInputRef = new RexInputRef(inputRefIndx, inputRel.getRowType()
        .getFieldList().get(inputRefIndx).getType());
    exprNode = rexInputRef.accept(exprConv);
    return exprNode;
  }

  /**
   * Converts the given input-ref positions of {@code inputRel} into
   * ExprNodeDescs; positions that resolve to a literal in the rel's child
   * expressions are converted directly as literals.
   */
  public static List<ExprNodeDesc> getExprNodes(List<Integer> inputRefs, RelNode inputRel,
      String inputTabAlias) {
    List<ExprNodeDesc> exprNodes = new ArrayList<ExprNodeDesc>();
    List<RexNode> rexInputRefs = getInputRef(inputRefs, inputRel);
    List<RexNode> exprs = inputRel.getChildExps();
    // TODO: Change ExprNodeConverter to be independent of Partition Expr
    ExprNodeConverter exprConv = new ExprNodeConverter(inputTabAlias, inputRel.getRowType(),
        new HashSet<Integer>(), inputRel.getCluster().getTypeFactory());
    for (int index = 0; index < rexInputRefs.size(); index++) {
      // The following check is only a guard against failures.
      // TODO: Knowing which expr is constant in GBY's aggregation function
      // arguments could be better done using Metadata provider of Calcite.
      //check the corresponding expression in exprs to see if it is literal
      // NOTE(review): the bounds check uses `index` but the lookup uses
      // `inputRefs.get(index)` — these can differ, so exprs.get(...) may still
      // go out of bounds; verify the intended guard.
      if (exprs != null && index < exprs.size()
          && exprs.get(inputRefs.get(index)) instanceof RexLiteral) {
        //because rexInputRefs represent ref expr corresponding to value in inputRefs
        // it is used to get corresponding index
        ExprNodeDesc exprNodeDesc = exprConv
            .visitLiteral((RexLiteral) exprs.get(inputRefs.get(index)));
        exprNodes.add(exprNodeDesc);
      } else {
        RexNode iRef = rexInputRefs.get(index);
        exprNodes.add(iRef.accept(exprConv));
      }
    }
    return exprNodes;
  }

  /** Returns the field names of {@code inputRel} at the requested positions, in order. */
  public static List<String> getFieldNames(List<Integer> inputRefs, RelNode inputRel) {
    List<String> fieldNames = new ArrayList<String>();
    List<String> schemaNames = inputRel.getRowType().getFieldNames();
    for (Integer iRef : inputRefs) {
      fieldNames.add(schemaNames.get(iRef));
    }

    return fieldNames;
  }

  /**
   * Creates a non-distinct AggregateCall for {@code funcName} over the single
   * argument at position {@code pos}, whose Hive type is {@code typeInfo}.
   */
  public static AggregateCall createSingleArgAggCall(String funcName, RelOptCluster cluster,
      PrimitiveTypeInfo typeInfo, Integer pos, RelDataType aggFnRetType) {
    ImmutableList.Builder<RelDataType> aggArgRelDTBldr = new ImmutableList.Builder<RelDataType>();
    aggArgRelDTBldr.add(TypeConverter.convert(typeInfo, cluster.getTypeFactory()));
    SqlAggFunction aggFunction = SqlFunctionConverter.getCalciteAggFn(funcName, false,
        aggArgRelDTBldr.build(), aggFnRetType);
    List<Integer> argList = new ArrayList<Integer>();
    argList.add(pos);
    return new AggregateCall(aggFunction, false, argList, aggFnRetType, null);
  }

  /**
   * Creates a HiveTableFunctionScan over the "replicate_rows" UDTF applied to
   * all of {@code input}'s columns; used to implement set operations (e.g.
   * INTERSECT ALL / EXCEPT ALL).
   */
  public static HiveTableFunctionScan createUDTFForSetOp(RelOptCluster cluster, RelNode input)
      throws SemanticException {
    RelTraitSet traitSet = TraitsUtil.getDefaultTraitSet(cluster);

    // Identity refs over every input column become the UDTF's arguments.
    List<RexNode> originalInputRefs = Lists.transform(input.getRowType().getFieldList(),
        new Function<RelDataTypeField, RexNode>() {
          @Override
          public RexNode apply(RelDataTypeField input) {
            return new RexInputRef(input.getIndex(), input.getType());
          }
        });
    ImmutableList.Builder<RelDataType> argTypeBldr = ImmutableList.<RelDataType> builder();
    for (int i = 0; i < originalInputRefs.size(); i++) {
      argTypeBldr.add(originalInputRefs.get(i).getType());
    }

    RelDataType retType = input.getRowType();

    String funcName = "replicate_rows";
    FunctionInfo fi = FunctionRegistry.getFunctionInfo(funcName);
    SqlOperator calciteOp = SqlFunctionConverter.getCalciteOperator(funcName, fi.getGenericUDTF(),
        argTypeBldr.build(), retType);

    // Hive UDTF only has a single input
    List<RelNode> list = new ArrayList<>();
    list.add(input);

    RexNode rexNode = cluster.getRexBuilder().makeCall(calciteOp, originalInputRefs);

    return HiveTableFunctionScan.create(cluster, traitSet, list, rexNode, null, retType, null);
  }

  /**
   * Creates a HiveProject over {@code input} that drops the columns at the
   * given positions and keeps all others in their original order.
   */
  public static HiveProject createProjectWithoutColumn(RelNode input, Set<Integer> positions)
      throws CalciteSemanticException {
    List<RexNode> originalInputRefs = Lists.transform(input.getRowType().getFieldList(),
        new Function<RelDataTypeField, RexNode>() {
          @Override
          public RexNode apply(RelDataTypeField input) {
            return new RexInputRef(input.getIndex(), input.getType());
          }
        });
    List<RexNode> copyInputRefs = new ArrayList<>();
    for (int i = 0; i < originalInputRefs.size(); i++) {
      if (!positions.contains(i)) {
        copyInputRefs.add(originalInputRefs.get(i));
      }
    }
    return HiveProject.create(input, copyInputRefs, null);
  }

  /**
   * Walks over an expression and determines whether it is constant.
*/ public static class ConstantFinder implements RexVisitor<Boolean> { @Override public Boolean visitLiteral(RexLiteral literal) { return true; } @Override public Boolean visitInputRef(RexInputRef inputRef) { return false; } @Override public Boolean visitLocalRef(RexLocalRef localRef) { throw new RuntimeException("Not expected to be called."); } @Override public Boolean visitOver(RexOver over) { return false; } @Override public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) { return false; } @Override public Boolean visitDynamicParam(RexDynamicParam dynamicParam) { return false; } @Override public Boolean visitCall(RexCall call) { // Constant if operator is deterministic and all operands are // constant. return call.getOperator().isDeterministic() && RexVisitorImpl.visitArrayAnd(this, call.getOperands()); } @Override public Boolean visitRangeRef(RexRangeRef rangeRef) { return false; } @Override public Boolean visitFieldAccess(RexFieldAccess fieldAccess) { // "<expr>.FIELD" is constant iff "<expr>" is constant. return fieldAccess.getReferenceExpr().accept(this); } @Override public Boolean visitSubQuery(RexSubQuery subQuery) { // it seems that it is not used by anything. return false; } @Override public Boolean visitPatternFieldRef(RexPatternFieldRef fieldRef) { return false; } } public static Set<Integer> getInputRefs(RexNode expr) { InputRefsCollector irefColl = new InputRefsCollector(true); expr.accept(irefColl); return irefColl.getInputRefSet(); } private static class InputRefsCollector extends RexVisitorImpl<Void> { private final Set<Integer> inputRefSet = new HashSet<Integer>(); private InputRefsCollector(boolean deep) { super(deep); } @Override public Void visitInputRef(RexInputRef inputRef) { inputRefSet.add(inputRef.getIndex()); return null; } public Set<Integer> getInputRefSet() { return inputRefSet; } } }