/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.parse; import java.util.ArrayList; import java.util.List; import junit.framework.Assert; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.QueryState; import org.apache.hadoop.hive.ql.session.SessionState; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; public class TestQBJoinTreeApplyPredicate { static QueryState queryState; static HiveConf conf; SemanticAnalyzer sA; @BeforeClass public static void initialize() { queryState = new QueryState.Builder().withHiveConf(new HiveConf(SemanticAnalyzer.class)).build(); conf = queryState.getConf(); SessionState.start(conf); } @Before public void setup() throws SemanticException { sA = new CalcitePlanner(queryState); } static ASTNode constructIdentifier(String nm) { return (ASTNode) ParseDriver.adaptor.create(HiveParser.Identifier, nm); } static ASTNode constructTabRef(String tblNm) { ASTNode table = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_TABLE_OR_COL, "TOK_TABLE_OR_COL"); ASTNode id = constructIdentifier(tblNm); table.addChild(id); return table; } static ASTNode constructColRef(String tblNm, String colNm) { ASTNode table = constructTabRef(tblNm); ASTNode col = constructIdentifier(colNm); ASTNode dot = (ASTNode) ParseDriver.adaptor.create(HiveParser.DOT, "."); dot.addChild(table); dot.addChild(col); return dot; } static ASTNode constructEqualityCond(String lTbl, String lCol, String rTbl, String rCol) { ASTNode lRef = constructColRef(lTbl, lCol); ASTNode rRef = constructColRef(rTbl, rCol); ASTNode eq = (ASTNode) ParseDriver.adaptor.create(HiveParser.EQUAL, "="); eq.addChild(lRef); eq.addChild(rRef); return eq; } QBJoinTree createJoinTree(JoinType type, String leftAlias, QBJoinTree leftTree, String rightAlias) { QBJoinTree jT = new QBJoinTree(); JoinCond[] condn = new JoinCond[1]; condn[0] = new JoinCond(0, 1, type); if ( leftTree == null ) { jT.setLeftAlias(leftAlias); String[] leftAliases = new String[1]; leftAliases[0] = leftAlias; jT.setLeftAliases(leftAliases); } else { jT.setJoinSrc(leftTree); String[] leftChildAliases = leftTree.getLeftAliases(); String leftAliases[] = new String[leftChildAliases.length + 1]; for (int i = 0; i < leftChildAliases.length; i++) { leftAliases[i] = leftChildAliases[i]; } leftAliases[leftChildAliases.length] = leftTree.getRightAliases()[0]; jT.setLeftAliases(leftAliases); } String[] rightAliases = new String[1]; rightAliases[0] = rightAlias; jT.setRightAliases(rightAliases); String[] children = new String[2]; children[0] = leftAlias; children[1] = rightAlias; jT.setBaseSrc(children); ArrayList<ArrayList<ASTNode>> expressions = new ArrayList<ArrayList<ASTNode>>(); expressions.add(new ArrayList<ASTNode>()); expressions.add(new ArrayList<ASTNode>()); jT.setExpressions(expressions); ArrayList<Boolean> nullsafes = new ArrayList<Boolean>(); jT.setNullSafes(nullsafes); ArrayList<ArrayList<ASTNode>> filters = new ArrayList<ArrayList<ASTNode>>(); filters.add(new ArrayList<ASTNode>()); filters.add(new ArrayList<ASTNode>()); jT.setFilters(filters); jT.setFilterMap(new int[2][]); ArrayList<ArrayList<ASTNode>> filtersForPushing = new ArrayList<ArrayList<ASTNode>>(); filtersForPushing.add(new ArrayList<ASTNode>()); filtersForPushing.add(new ArrayList<ASTNode>()); jT.setFiltersForPushing(filtersForPushing); return jT; } ASTNode applyEqPredicate(QBJoinTree jT, String lTbl, String lCol, String rTbl, String rCol) throws SemanticException { ASTNode joinCond = constructEqualityCond(lTbl, lCol, rTbl, rCol); ASTNode leftCondn = (ASTNode) joinCond.getChild(0); ASTNode rightCondn = (ASTNode) joinCond.getChild(1); List<String> leftSrc = new ArrayList<String>(); ArrayList<String> leftCondAl1 = new ArrayList<String>(); ArrayList<String> leftCondAl2 = new ArrayList<String>(); ArrayList<String> rightCondAl1 = new ArrayList<String>(); ArrayList<String> rightCondAl2 = new ArrayList<String>(); sA.parseJoinCondPopulateAlias(jT, leftCondn, leftCondAl1, leftCondAl2, null, null); sA.parseJoinCondPopulateAlias(jT, rightCondn, rightCondAl1, rightCondAl2, null, null); sA.applyEqualityPredicateToQBJoinTree(jT, JoinType.INNER, leftSrc, joinCond, leftCondn, rightCondn, leftCondAl1, leftCondAl2, rightCondAl1, rightCondAl2); return joinCond; } @Test public void testSimpleCondn() throws SemanticException { QBJoinTree jT = createJoinTree(JoinType.INNER, "a", null, "b"); ASTNode joinCond = applyEqPredicate(jT, "a", "x", "b", "y"); Assert.assertEquals(jT.getExpressions().get(0).get(0), joinCond.getChild(0)); Assert.assertEquals(jT.getExpressions().get(1).get(0), joinCond.getChild(1)); } @Test public void test3WayJoin() throws SemanticException { QBJoinTree jT1 = createJoinTree(JoinType.INNER, "a", null, "b"); QBJoinTree jT = createJoinTree(JoinType.INNER, "b", jT1, "c"); ASTNode joinCond1 = applyEqPredicate(jT, "a", "x", "b", "y"); ASTNode joinCond2 = applyEqPredicate(jT, "b", "y", "c", "z"); Assert.assertEquals(jT1.getExpressions().get(0).get(0), joinCond1.getChild(0)); Assert.assertEquals(jT1.getExpressions().get(1).get(0), joinCond1.getChild(1)); Assert.assertEquals(jT.getExpressions().get(0).get(0), joinCond2.getChild(0)); Assert.assertEquals(jT.getExpressions().get(1).get(0), joinCond2.getChild(1)); } @Test public void test3WayJoinSwitched() throws SemanticException { QBJoinTree jT1 = createJoinTree(JoinType.INNER, "a", null, "b"); QBJoinTree jT = createJoinTree(JoinType.INNER, "b", jT1, "c"); ASTNode joinCond1 = applyEqPredicate(jT, "b", "y", "a", "x"); ASTNode joinCond2 = applyEqPredicate(jT, "b", "y", "c", "z"); Assert.assertEquals(jT1.getExpressions().get(0).get(0), joinCond1.getChild(1)); Assert.assertEquals(jT1.getExpressions().get(1).get(0), joinCond1.getChild(0)); Assert.assertEquals(jT.getExpressions().get(0).get(0), joinCond2.getChild(0)); Assert.assertEquals(jT.getExpressions().get(1).get(0), joinCond2.getChild(1)); } @Test public void test4WayJoin() throws SemanticException { QBJoinTree jT1 = createJoinTree(JoinType.INNER, "a", null, "b"); QBJoinTree jT2 = createJoinTree(JoinType.INNER, "b", jT1, "c"); QBJoinTree jT = createJoinTree(JoinType.INNER, "c", jT2, "d"); ASTNode joinCond1 = applyEqPredicate(jT, "a", "x", "b", "y"); ASTNode joinCond2 = applyEqPredicate(jT, "b", "y", "c", "z"); ASTNode joinCond3 = applyEqPredicate(jT, "a", "x", "c", "z"); Assert.assertEquals(jT1.getExpressions().get(0).get(0), joinCond1.getChild(0)); Assert.assertEquals(jT1.getExpressions().get(1).get(0), joinCond1.getChild(1)); Assert.assertEquals(jT2.getExpressions().get(0).get(0), joinCond2.getChild(0)); Assert.assertEquals(jT2.getExpressions().get(1).get(0), joinCond2.getChild(1)); Assert.assertEquals(jT2.getExpressions().get(0).get(1), joinCond3.getChild(0)); Assert.assertEquals(jT2.getExpressions().get(1).get(1), joinCond3.getChild(1)); } @Test public void test4WayJoinSwitched() throws SemanticException { QBJoinTree jT1 = createJoinTree(JoinType.INNER, "a", null, "b"); QBJoinTree jT2 = createJoinTree(JoinType.INNER, "b", jT1, "c"); QBJoinTree jT = createJoinTree(JoinType.INNER, "c", jT2, "d"); ASTNode joinCond1 = applyEqPredicate(jT, "b", "y", "a", "x"); ASTNode joinCond2 = applyEqPredicate(jT, "b", "y", "c", "z"); ASTNode joinCond3 = applyEqPredicate(jT, "c", "z", "a", "x"); Assert.assertEquals(jT1.getExpressions().get(0).get(0), joinCond1.getChild(1)); Assert.assertEquals(jT1.getExpressions().get(1).get(0), joinCond1.getChild(0)); Assert.assertEquals(jT2.getExpressions().get(0).get(0), joinCond2.getChild(0)); Assert.assertEquals(jT2.getExpressions().get(1).get(0), joinCond2.getChild(1)); Assert.assertEquals(jT2.getExpressions().get(0).get(1), joinCond3.getChild(1)); Assert.assertEquals(jT2.getExpressions().get(1).get(1), joinCond3.getChild(0)); } }