/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;
import java.util.List;
import static org.testng.Assert.assertEquals;
public class TestMemo
{
private PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
@Test
public void testInitialization()
throws Exception
{
PlanNode plan = node(node());
Memo memo = new Memo(idAllocator, plan);
assertEquals(memo.getGroupCount(), 2);
assertMatchesStructure(plan, memo.extract());
}
/*
From: X -> Y -> Z
To: X -> Y' -> Z'
*/
@Test
public void testReplaceSubtree()
throws Exception
{
PlanNode plan = node(node(node()));
Memo memo = new Memo(idAllocator, plan);
assertEquals(memo.getGroupCount(), 3);
// replace child of root node with subtree
PlanNode transformed = node(node());
memo.replace(getChildGroup(memo, memo.getRootGroup()), transformed, "rule");
assertEquals(memo.getGroupCount(), 3);
assertMatchesStructure(memo.extract(), node(plan.getId(), transformed));
}
/*
From: X -> Y -> Z -> W
To: X -> Y' -> Z' -> W
*/
@Test
public void testReplaceNonLeafSubtree()
throws Exception
{
PlanNode w = node();
PlanNode z = node(w);
PlanNode y = node(z);
PlanNode x = node(y);
Memo memo = new Memo(idAllocator, x);
assertEquals(memo.getGroupCount(), 4);
int yGroup = getChildGroup(memo, memo.getRootGroup());
int zGroup = getChildGroup(memo, yGroup);
PlanNode rewrittenW = memo.getNode(zGroup).getSources().get(0);
PlanNode newZ = node(rewrittenW);
PlanNode newY = node(newZ);
memo.replace(yGroup, newY, "rule");
assertEquals(memo.getGroupCount(), 4);
assertMatchesStructure(
memo.extract(),
node(x.getId(),
node(newY.getId(),
node(newZ.getId(),
node(w.getId())))));
}
/*
From: X -> Y -> Z
To: X -> Z
*/
@Test
public void testRemoveNode()
throws Exception
{
PlanNode z = node();
PlanNode y = node(z);
PlanNode x = node(y);
Memo memo = new Memo(idAllocator, x);
assertEquals(memo.getGroupCount(), 3);
int yGroup = getChildGroup(memo, memo.getRootGroup());
memo.replace(yGroup, memo.getNode(yGroup).getSources().get(0), "rule");
assertEquals(memo.getGroupCount(), 2);
assertMatchesStructure(
memo.extract(),
node(x.getId(),
node(z.getId())));
}
/*
From: X -> Z
To: X -> Y -> Z
*/
@Test
public void testInsertNode()
throws Exception
{
PlanNode z = node();
PlanNode x = node(z);
Memo memo = new Memo(idAllocator, x);
assertEquals(memo.getGroupCount(), 2);
int zGroup = getChildGroup(memo, memo.getRootGroup());
PlanNode y = node(memo.getNode(zGroup));
memo.replace(zGroup, y, "rule");
assertEquals(memo.getGroupCount(), 3);
assertMatchesStructure(
memo.extract(),
node(x.getId(),
node(y.getId(),
node(z.getId()))));
}
/*
From: X -> Y -> Z
To: X --> Y1' --> Z
\-> Y2' -/
*/
@Test
public void testMultipleReferences()
throws Exception
{
PlanNode z = node();
PlanNode y = node(z);
PlanNode x = node(y);
Memo memo = new Memo(idAllocator, x);
assertEquals(memo.getGroupCount(), 3);
int yGroup = getChildGroup(memo, memo.getRootGroup());
PlanNode rewrittenZ = memo.getNode(yGroup).getSources().get(0);
PlanNode y1 = node(rewrittenZ);
PlanNode y2 = node(rewrittenZ);
PlanNode newX = node(y1, y2);
memo.replace(memo.getRootGroup(), newX, "rule");
assertEquals(memo.getGroupCount(), 4);
assertMatchesStructure(
memo.extract(),
node(newX.getId(),
node(y1.getId(), node(z.getId())),
node(y2.getId(), node(z.getId()))));
}
private static void assertMatchesStructure(PlanNode actual, PlanNode expected)
{
assertEquals(actual.getClass(), expected.getClass());
assertEquals(actual.getId(), expected.getId());
assertEquals(actual.getSources().size(), expected.getSources().size());
for (int i = 0; i < actual.getSources().size(); i++) {
assertMatchesStructure(actual.getSources().get(i), expected.getSources().get(i));
}
}
private int getChildGroup(Memo memo, int group)
{
PlanNode node = memo.getNode(group);
GroupReference child = (GroupReference) node.getSources().get(0);
return child.getGroupId();
}
private GenericNode node(PlanNodeId id, PlanNode... children)
{
return new GenericNode(id, ImmutableList.copyOf(children));
}
private GenericNode node(PlanNode... children)
{
return node(idAllocator.getNextId(), children);
}
private static class GenericNode
extends PlanNode
{
private final List<PlanNode> sources;
public GenericNode(PlanNodeId id, List<PlanNode> sources)
{
super(id);
this.sources = ImmutableList.copyOf(sources);
}
@Override
public List<PlanNode> getSources()
{
return sources;
}
@Override
public List<Symbol> getOutputSymbols()
{
return ImmutableList.of();
}
@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
return new GenericNode(getId(), newChildren);
}
}
}