package com.taobao.tddl.optimizer.costbased;

import java.util.List;

import org.junit.Assert;
import org.junit.Test;

import com.taobao.tddl.optimizer.BaseOptimizerTest;
import com.taobao.tddl.optimizer.core.ast.query.JoinNode;
import com.taobao.tddl.optimizer.core.ast.query.QueryNode;
import com.taobao.tddl.optimizer.core.ast.query.TableNode;
import com.taobao.tddl.optimizer.core.expression.IFilter;
import com.taobao.tddl.optimizer.costbased.pusher.FilterPusher;
import com.taobao.tddl.optimizer.utils.FilterUtils;
import com.taobao.tddl.optimizer.utils.OptimizerUtils;

/**
 * Tests for {@link FilterPusher#optimize}: each case builds a small query tree
 * (table joins, optionally wrapped in sub-queries), runs the pusher, and then
 * checks the textual form of the filters found on the child nodes and on the
 * join itself.
 *
 * <p>Test method names are kept in their original (partly Chinese) form because
 * they are identifiers reported by the test runner.
 */
public class FilterPusherTest extends BaseOptimizerTest {

    /**
     * WHERE contains an OR spanning both tables: after pre-processing and
     * pushing, neither child carries a where filter and the join filter list
     * is empty.
     */
    @Test
    public void test_where中OR条件不可下推() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.query("(TABLE1.ID>5 OR TABLE2.ID<10) AND TABLE1.NAME = TABLE2.NAME");
        join.build();
        FilterPreProcessor.optimize(join, true);
        FilterPusher.optimize(join);

        Assert.assertNull(join.getLeftNode().getWhereFilter());
        Assert.assertNull(join.getRightNode().getWhereFilter());
        Assert.assertTrue(join.getJoinFilter().isEmpty());
    }

    /**
     * Plain AND-ed WHERE: single-table predicates end up on their own table,
     * the column-to-column equality ends up as the join filter.
     */
    @Test
    public void test_where条件下推() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.NAME = TABLE2.NAME");
        join.build();
        FilterPusher.optimize(join);

        Assert.assertEquals("TABLE1.ID > 5", join.getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE2.ID < 10", join.getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME", join.getJoinFilter().get(0).toString());
    }

    /**
     * The equality between the two tables involves an expression
     * ({@code TABLE2.NAME + 1}): the single-table predicates are still pushed,
     * but the equality is expected to remain on the join's where filter.
     */
    @Test
    public void test_where条件下推_join列存在函数不处理() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.NAME = TABLE2.NAME + 1");
        join.build();
        FilterPusher.optimize(join);

        Assert.assertEquals("TABLE1.ID > 5", join.getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE2.ID < 10", join.getRightNode().getWhereFilter().toString());
        // The expression-based equality still stays in the where filter.
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME + 1", join.getWhereFilter().toString());
    }

    /**
     * The join equality is on the same column the range predicates use
     * ({@code ID}): each side is expected to receive both range predicates,
     * rewritten against its own table.
     */
    @Test
    public void test_where条件下推_条件推导下推() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.ID = TABLE2.ID");
        join.build();
        FilterPusher.optimize(join);

        Assert.assertEquals("(TABLE1.ID > 5 AND TABLE1.ID < 10)", join.getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("(TABLE2.ID < 10 AND TABLE2.ID > 5)", join.getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.ID = TABLE2.ID", join.getJoinFilter().get(0).toString());
    }

    /**
     * A filter on the outer query ({@code S.SCHOOL = 6}) over an aliased join
     * is expected to reach TABLE1 as {@code TABLE1.SCHOOL = 6}.
     */
    @Test
    public void test_where条件下推_子查询() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.alias("S").select("TABLE1.ID,TABLE1.NAME,TABLE1.SCHOOL");
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.NAME = TABLE2.NAME");
        join.build();

        QueryNode query = new QueryNode(join);
        query.query("S.SCHOOL = 6");
        query.build();
        FilterPusher.optimize(query);

        Assert.assertEquals("(TABLE1.SCHOOL = 6 AND TABLE1.ID > 5)", join.getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE2.ID < 10", join.getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME", join.getJoinFilter().get(0).toString());
    }

    /**
     * The outer filter targets a selected column that is a function result
     * ({@code CONCAT_WS(...) AS NAME}): the children's filters are expected to
     * be the same as without the outer filter.
     */
    @Test
    public void test_where条件下推_子查询_函数列不传递() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.alias("S").select("TABLE1.ID");
        join.getColumnsSelected()
            .add(OptimizerUtils.createColumnFromString("CONCAT_WS(' ',TABLE1.NAME,TABLE1.SCHOOL) AS NAME"));
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.NAME = TABLE2.NAME");
        join.build();

        QueryNode query = new QueryNode(join);
        query.query("S.NAME = 1");
        query.build();
        FilterPusher.optimize(query);

        Assert.assertEquals("TABLE1.ID > 5", join.getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE2.ID < 10", join.getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME", join.getJoinFilter().get(0).toString());
    }

    /**
     * Two nested sub-queries: the filter on the function column
     * ({@code S.NAME = 1}) is not expected on the tables, while the filter on
     * the plain column ({@code B.ID = 6}) is expected on TABLE1.
     */
    @Test
    public void test_where条件下推_多级子查询_函数列不传递_字段列传递() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.alias("S").select("TABLE1.ID");
        join.getColumnsSelected()
            .add(OptimizerUtils.createColumnFromString("CONCAT_WS(' ',TABLE1.NAME,TABLE1.SCHOOL) AS NAME"));
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.NAME = TABLE2.NAME");
        join.build();

        QueryNode query = new QueryNode(join);
        query.alias("B").query("S.NAME = 1");
        query.build();

        QueryNode nextQuery = new QueryNode(query);
        nextQuery.query("B.ID = 6");
        nextQuery.build();
        FilterPusher.optimize(nextQuery);

        Assert.assertEquals("(TABLE1.ID = 6 AND TABLE1.ID > 5)", join.getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE2.ID < 10", join.getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME", join.getJoinFilter().get(0).toString());
    }

    /**
     * Join of a join with a third table: the outer {@code TABLE1.NAME = 6} is
     * expected on TABLE1 and, rewritten as {@code TABLE2.NAME = 6}, on TABLE2
     * (the inner join equates the two NAME columns).
     */
    @Test
    public void test_where条件下推_多级join() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        TableNode table3 = new TableNode("TABLE3");
        JoinNode join = table1.join(table2);
        join.select("TABLE1.ID AS ID , TABLE1.NAME AS NAME");
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.NAME = TABLE2.NAME");
        join.build();

        JoinNode nextJoin = join.join(table3);
        nextJoin.query("TABLE1.NAME = 6 AND TABLE1.ID = TABLE3.ID");
        nextJoin.build();
        FilterPusher.optimize(nextJoin);

        Assert.assertEquals("(TABLE1.NAME = 6 AND TABLE1.ID > 5)",
            ((JoinNode) nextJoin.getLeftNode()).getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("(TABLE2.ID < 10 AND TABLE2.NAME = 6)",
            ((JoinNode) nextJoin.getLeftNode()).getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.ID = TABLE3.ID", nextJoin.getJoinFilter().get(0).toString());
    }

    /**
     * Join of two sub-queries (aliases B and C, the second a deep copy of the
     * first) over the same inner join: checks the filters expected on the
     * tables underneath each side and the join filters at both levels.
     */
    @Test
    public void test_where条件下推_多级join_子查询() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2);
        join.alias("S").select("TABLE1.ID AS ID , TABLE1.NAME AS NAME , TABLE2.SCHOOL AS SCHOOL");
        join.query("TABLE1.ID>5 AND TABLE2.ID<10 AND TABLE1.NAME = TABLE2.NAME");
        join.build();

        QueryNode queryA = new QueryNode(join);
        queryA.alias("B").query("S.NAME = 2");
        queryA.build();

        QueryNode queryB = queryA.deepCopy();
        queryB.alias("C").query("S.NAME = 3");
        queryB.build();

        JoinNode nextJoin = queryA.join(queryB);
        nextJoin.query("C.SCHOOL = 4 AND B.ID = C.ID");
        nextJoin.build();
        FilterPusher.optimize(nextJoin);

        Assert.assertEquals("B.ID = C.ID", nextJoin.getJoinFilter().get(0).toString());

        // Left side (alias B).
        Assert.assertEquals("(TABLE1.NAME = 2 AND TABLE1.ID > 5)",
            ((JoinNode) nextJoin.getLeftNode().getChild()).getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("(TABLE2.ID < 10 AND TABLE2.NAME = 2)",
            ((JoinNode) nextJoin.getLeftNode().getChild()).getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME",
            ((JoinNode) nextJoin.getLeftNode().getChild()).getJoinFilter().get(0).toString());

        // Right side (alias C).
        Assert.assertEquals("(TABLE1.NAME = 3 AND TABLE1.ID > 5)",
            ((JoinNode) nextJoin.getRightNode().getChild()).getLeftNode().getWhereFilter().toString());
        Assert.assertEquals("(TABLE2.SCHOOL = 4 AND TABLE2.ID < 10 AND TABLE2.NAME = 3)",
            ((JoinNode) nextJoin.getRightNode().getChild()).getRightNode().getWhereFilter().toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME",
            ((JoinNode) nextJoin.getRightNode().getChild()).getJoinFilter().get(0).toString());
    }

    /**
     * Same idea as the WHERE case, but the single-table predicates start out
     * as the join's "other join-on" filter and are expected on each child's
     * other-join-on filter.
     */
    @Test
    public void test_join条件下推() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2).addJoinKeys("TABLE1.NAME", "TABLE2.NAME");
        addOtherJoinFilter(join, "TABLE1.ID>5 AND TABLE2.ID<10");
        join.build();
        FilterPusher.optimize(join);

        Assert.assertEquals("TABLE1.ID > 5", join.getLeftNode().getOtherJoinOnFilter().toString());
        Assert.assertEquals("TABLE2.ID < 10", join.getRightNode().getOtherJoinOnFilter().toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME", join.getJoinFilter().get(0).toString());
    }

    /**
     * Join keys are on {@code ID}, the same column the other-join-on
     * predicates use: each child is expected to receive both range predicates
     * rewritten against its own table.
     */
    @Test
    public void test_join条件下推_条件推导下推() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        JoinNode join = table1.join(table2).addJoinKeys("TABLE1.ID", "TABLE2.ID");
        addOtherJoinFilter(join, "TABLE1.ID>5 AND TABLE2.ID<10");
        join.build();
        FilterPusher.optimize(join);

        Assert.assertEquals("(TABLE1.ID > 5 AND TABLE1.ID < 10)",
            join.getLeftNode().getOtherJoinOnFilter().toString());
        Assert.assertEquals("(TABLE2.ID < 10 AND TABLE2.ID > 5)",
            join.getRightNode().getOtherJoinOnFilter().toString());
        Assert.assertEquals("TABLE1.ID = TABLE2.ID", join.getJoinFilter().get(0).toString());
    }

    /**
     * Two-level join where the outer other-join-on filter ({@code A.NAME = 6})
     * refers to the aliased inner join: it is expected on both inner tables'
     * other-join-on filters, rewritten to the underlying NAME columns.
     */
    @Test
    public void test_join条件下推_多级join() {
        TableNode table1 = new TableNode("TABLE1");
        TableNode table2 = new TableNode("TABLE2");
        TableNode table3 = new TableNode("TABLE3");
        JoinNode join = table1.join(table2).addJoinKeys("TABLE1.NAME", "TABLE2.NAME");
        join.alias("A").select("TABLE1.ID AS ID , TABLE1.NAME AS NAME");
        addOtherJoinFilter(join, "TABLE1.ID>5 AND TABLE2.ID<10");
        join.build();

        JoinNode nextJoin = join.join(table3).addJoinKeys("A.ID", "TABLE3.ID");
        addOtherJoinFilter(nextJoin, "A.NAME = 6");
        nextJoin.build();
        FilterPusher.optimize(nextJoin);

        Assert.assertEquals("A.ID = TABLE3.ID", nextJoin.getJoinFilter().get(0).toString());
        Assert.assertEquals("TABLE1.NAME = TABLE2.NAME",
            ((JoinNode) nextJoin.getLeftNode()).getJoinFilter().get(0).toString());
        Assert.assertEquals("(TABLE1.NAME = 6 AND TABLE1.ID > 5)",
            ((JoinNode) nextJoin.getLeftNode()).getLeftNode().getOtherJoinOnFilter().toString());
        Assert.assertEquals("(TABLE2.ID < 10 AND TABLE2.NAME = 6)",
            ((JoinNode) nextJoin.getLeftNode()).getRightNode().getOtherJoinOnFilter().toString());
    }

    /**
     * Parses {@code filter}, converts it to DNF, and installs the first
     * conjunction (rebuilt as an AND tree) as the join's other-join-on filter.
     *
     * @param jn     join node to receive the filter
     * @param filter filter expression text; only the first DNF group is used
     */
    private void addOtherJoinFilter(JoinNode jn, String filter) {
        IFilter f = FilterUtils.createFilter(filter);
        List<List<IFilter>> dnfFilters = FilterUtils.toDNFNodesArray(FilterUtils.toDNFAndFlat(f));
        jn.setOtherJoinOnFilter(FilterUtils.DNFToAndLogicTree(dnfFilters.get(0)));
    }
}