/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;
import java.util.Arrays;
import java.util.Optional;
import static com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins.isOriginalOrder;
import static com.facebook.presto.sql.tree.ArithmeticUnaryExpression.Sign.MINUS;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static org.testng.Assert.assertEquals;
import static org.testng.AssertJUnit.assertFalse;
import static org.testng.AssertJUnit.assertTrue;
/**
 * Tests for {@link EliminateCrossJoins#isOriginalOrder}, {@link EliminateCrossJoins#getJoinOrder}
 * and {@link JoinGraph#buildFrom}.
 *
 * <p>Plans are assembled from the private {@code values}, {@code join} and {@code project}
 * helpers at the bottom of the class. All of them draw node ids from one shared
 * {@link PlanNodeIdAllocator}.
 */
// NOTE(review): singleThreaded is presumably required because every test method shares idAllocator - confirm.
@Test(singleThreaded = true)
public class TestEliminateCrossJoins
{
    // Only used by the helper methods of this class and never reassigned.
    private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();

    /**
     * An ascending sequence is reported as the original order; a sequence with two swapped
     * entries is not.
     */
    @Test
    public void testIsOriginalOrder()
    {
        assertTrue(isOriginalOrder(ImmutableList.of(0, 1, 2, 3, 4)));
        assertFalse(isOriginalOrder(ImmutableList.of(0, 2, 1, 3, 4)));
    }

    /**
     * {@code a} and {@code b} are joined without criteria, then joined with {@code c}
     * on {@code a = c} and {@code c = b}. The expected order is {@code [0, 2, 1]}.
     */
    @Test
    public void testJoinOrder()
    {
        PlanNode plan =
                join(
                        join(
                                values(symbol("a")),
                                values(symbol("b"))),
                        values(symbol("c")),
                        symbol("a"), symbol("c"),
                        symbol("c"), symbol("b"));

        JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan));

        assertEquals(
                EliminateCrossJoins.getJoinOrder(joinGraph),
                ImmutableList.of(0, 2, 1));
    }

    /**
     * Two sub-plans shaped like the one in {@link #testJoinOrder()} are joined with each other
     * without any criteria. The expected order is {@code [0, 2, 1, 3, 5, 4]}.
     */
    @Test
    public void testJoinOrderWithRealCrossJoin()
    {
        PlanNode leftPlan =
                join(
                        join(
                                values(symbol("a")),
                                values(symbol("b"))),
                        values(symbol("c")),
                        symbol("a"), symbol("c"),
                        symbol("c"), symbol("b"));

        PlanNode rightPlan =
                join(
                        join(
                                values(symbol("x")),
                                values(symbol("y"))),
                        values(symbol("z")),
                        symbol("x"), symbol("z"),
                        symbol("z"), symbol("y"));

        // No symbols are passed, so this top-level join has an empty criteria list.
        PlanNode plan = join(leftPlan, rightPlan);

        JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan));

        assertEquals(
                EliminateCrossJoins.getJoinOrder(joinGraph),
                ImmutableList.of(0, 2, 1, 3, 5, 4));
    }

    /**
     * Same shape as {@link #testJoinOrder()}, but {@code b} and {@code c} are connected by two
     * equi-join clauses ({@code c1 = b1} and {@code c2 = b2}). The expected order is still
     * {@code [0, 2, 1]}.
     */
    @Test
    public void testJoinOrderWithMultipleEdgesBetweenNodes()
    {
        PlanNode plan =
                join(
                        join(
                                values(symbol("a")),
                                values(symbol("b1"), symbol("b2"))),
                        values(symbol("c1"), symbol("c2")),
                        symbol("a"), symbol("c1"),
                        symbol("c1"), symbol("b1"),
                        symbol("c2"), symbol("b2"));

        JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan));

        assertEquals(
                EliminateCrossJoins.getJoinOrder(joinGraph),
                ImmutableList.of(0, 2, 1));
    }

    /**
     * Both joins carry criteria ({@code a = b}, then {@code c = b}), and the expected order is
     * the unchanged {@code [0, 1, 2]}.
     */
    @Test
    public void testDonNotChangeOrderWithoutCrossJoin()
    {
        PlanNode plan =
                join(
                        join(
                                values(symbol("a")),
                                values(symbol("b")),
                                symbol("a"), symbol("b")),
                        values(symbol("c")),
                        symbol("c"), symbol("b"));

        JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan));

        assertEquals(
                EliminateCrossJoins.getJoinOrder(joinGraph),
                ImmutableList.of(0, 1, 2));
    }

    /**
     * {@code a} and {@code b} are joined without criteria and the only clause of the outer join
     * is {@code c = b}, so no clause references {@code a}. The expected order is the unchanged
     * {@code [0, 1, 2]}.
     */
    @Test
    public void testDoNotReorderCrossJoins()
    {
        PlanNode plan =
                join(
                        join(
                                values(symbol("a")),
                                values(symbol("b"))),
                        values(symbol("c")),
                        symbol("c"), symbol("b"));

        JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan));

        assertEquals(
                EliminateCrossJoins.getJoinOrder(joinGraph),
                ImmutableList.of(0, 1, 2));
    }

    /**
     * A projection assigning {@code a2 = -a1} sits between the inner and the outer join.
     * {@link JoinGraph#buildFrom} is expected to return two graphs for this plan instead of one.
     */
    @Test
    public void testGiveUpOnNonIdentityProjections()
    {
        PlanNode plan =
                join(
                        project(
                                join(
                                        values(symbol("a1")),
                                        values(symbol("b"))),
                                symbol("a2"),
                                new ArithmeticUnaryExpression(MINUS, new SymbolReference("a1"))),
                        values(symbol("c")),
                        symbol("a2"), symbol("c"),
                        symbol("c"), symbol("b"));

        assertEquals(JoinGraph.buildFrom(plan).size(), 2);
    }

    /**
     * Wraps {@code source} in a {@link ProjectNode} with the single assignment
     * {@code symbol := expression}.
     *
     * @param source the node to project from
     * @param symbol name of the symbol being assigned
     * @param expression expression assigned to the symbol
     * @return the new projection node
     */
    private PlanNode project(PlanNode source, String symbol, Expression expression)
    {
        return new ProjectNode(
                idAllocator.getNextId(),
                source,
                Assignments.of(new Symbol(symbol), expression));
    }

    /**
     * Returns {@code name} unchanged; it only marks symbol names at the call sites of the
     * plan-building helpers.
     */
    private String symbol(String name)
    {
        return name;
    }

    /**
     * Builds an {@code INNER} {@link JoinNode} over {@code left} and {@code right}.
     *
     * <p>{@code symbols} is read pairwise: elements {@code 2k} and {@code 2k + 1} become the two
     * sides of the k-th {@link JoinNode.EquiJoinClause}. Passing no symbols yields a join with an
     * empty criteria list. The output symbols are those of {@code left} followed by those of
     * {@code right}.
     *
     * @param left left input
     * @param right right input
     * @param symbols symbol names, an even number of them
     * @return the new join node
     * @throws IllegalArgumentException if an odd number of symbols is given
     */
    private JoinNode join(PlanNode left, PlanNode right, String... symbols)
    {
        checkArgument(
                symbols.length % 2 == 0,
                "join symbols must be given in pairs, but got %s symbol(s)",
                symbols.length);

        ImmutableList.Builder<JoinNode.EquiJoinClause> criteria = ImmutableList.builder();
        for (int i = 0; i < symbols.length; i += 2) {
            criteria.add(new JoinNode.EquiJoinClause(new Symbol(symbols[i]), new Symbol(symbols[i + 1])));
        }

        return new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.INNER,
                left,
                right,
                criteria.build(),
                ImmutableList.<Symbol>builder()
                        .addAll(left.getOutputSymbols())
                        .addAll(right.getOutputSymbols())
                        .build(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Builds a {@link ValuesNode} that outputs the given symbols and has an empty row list.
     *
     * @param symbols names of the output symbols
     * @return the new values node
     */
    private ValuesNode values(String... symbols)
    {
        return new ValuesNode(
                idAllocator.getNextId(),
                Arrays.stream(symbols).map(Symbol::new).collect(toImmutableList()),
                ImmutableList.of());
    }
}