/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.plan.PlanNode;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
/**
* Stores a plan in a form that's efficient to mutate locally (i.e. without
* having to do full ancestor tree rewrites due to plan nodes being immutable).
*
* Each node in a plan is placed in a group, and its children are replaced with
* symbolic references to the corresponding groups.
*
* For example, a plan like:
*    A -> B -> C -> D
*           \> E -> F
*
* would be stored as:
*
* root: G0
*
* G0 : { A -> G1 }
* G1 : { B -> [G2, G3] }
* G2 : { C -> G4 }
* G3 : { E -> G5 }
* G4 : { D }
* G5 : { F }
*
* Groups are reference-counted, and groups that become unreachable from the root
* due to mutations in a subtree get garbage-collected.
*/
public class Memo
{
    private final PlanNodeIdAllocator idAllocator;
    private final int rootGroup;

    // group id -> the single plan node stored in that group (children are GroupReferences)
    private final Map<Integer, PlanNode> membership = new HashMap<>();
    // group id -> reference count maintained by increment/decrementReferenceCounts
    private final Map<Integer, Integer> referenceCounts = new HashMap<>();

    private int nextGroupId;

    /**
     * Builds a memo from the given plan by recursively inserting every node
     * into its own group.
     *
     * @param idAllocator allocator used to create ids for the generated group references
     * @param plan root of the plan to store
     */
    public Memo(PlanNodeIdAllocator idAllocator, PlanNode plan)
    {
        this.idAllocator = idAllocator;
        rootGroup = insertRecursive(plan);
        // Nothing inside the memo points at the root, so give it an explicit count of 1.
        referenceCounts.put(rootGroup, 1);
    }

    /**
     * @return the id of the group holding the root of the plan
     */
    public int getRootGroup()
    {
        return rootGroup;
    }

    /**
     * Returns the node currently stored in the given group.
     *
     * @param group group id
     * @return the node stored in the group
     * @throws IllegalArgumentException if the group is not present in this memo
     */
    public PlanNode getNode(int group)
    {
        checkArgument(membership.containsKey(group), "Invalid group: %s", group);
        return membership.get(group);
    }

    /**
     * Rebuilds a plan tree from the memo, starting at the root group and
     * substituting every group reference with the node stored in that group.
     *
     * @return the reconstructed plan
     */
    public PlanNode extract()
    {
        return extract(getNode(rootGroup));
    }

    /**
     * Recursively resolves group references below (and including) the given node.
     */
    private PlanNode extract(PlanNode node)
    {
        if (node instanceof GroupReference) {
            // Resolve through getNode so that a reference to a missing group is
            // reported as "Invalid group" instead of failing later with a
            // NullPointerException.
            return extract(getNode(((GroupReference) node).getGroupId()));
        }

        List<PlanNode> children = node.getSources().stream()
                .map(this::extract)
                .collect(Collectors.toList());

        return node.replaceChildren(children);
    }

    /**
     * Replaces the node stored in {@code group} with {@code node}.
     *
     * If {@code node} is itself a group reference, the node stored in the
     * referenced group is used. Otherwise the children of {@code node} are
     * inserted into the memo and replaced with group references.
     *
     * @param group id of the group whose content is replaced
     * @param node the replacement; must produce the same set of output symbols as the current node
     * @param reason description included in the error message when the outputs differ
     * @return the node that was actually stored in the group
     * @throws IllegalArgumentException if the group is not present in this memo,
     * or if the output symbols of the old and new node differ
     */
    public PlanNode replace(int group, PlanNode node, String reason)
    {
        // Validates the group id; a raw map lookup would yield null and fail
        // below with an uninformative NullPointerException.
        PlanNode old = getNode(group);

        checkArgument(new HashSet<>(old.getOutputSymbols()).equals(new HashSet<>(node.getOutputSymbols())),
                "%s: transformed expression doesn't produce same outputs: %s vs %s",
                reason,
                old.getOutputSymbols(),
                node.getOutputSymbols());

        if (node instanceof GroupReference) {
            node = getNode(((GroupReference) node).getGroupId());
        }
        else {
            node = insertChildrenAndRewrite(node);
        }

        // Increment before decrementing: groups referenced by both the old and
        // the new node must not hit zero (and get deleted) in between.
        incrementReferenceCounts(node);
        membership.put(group, node);
        decrementReferenceCounts(old);

        return node;
    }

    /**
     * Adds one to the reference count of every distinct group referenced by
     * the children of the given node.
     */
    private void incrementReferenceCounts(PlanNode node)
    {
        Set<Integer> references = getAllReferences(node);

        for (int group : references) {
            referenceCounts.compute(group, (g, count) -> count + 1);
        }
    }

    /**
     * Subtracts one from the reference count of every distinct group referenced
     * by the children of the given node. A group whose count reaches zero is
     * deleted, and the counts of the groups it referenced are decremented in turn.
     */
    private void decrementReferenceCounts(PlanNode node)
    {
        Set<Integer> references = getAllReferences(node);

        for (int group : references) {
            int newCount = referenceCounts.compute(group, (g, count) -> count - 1);
            checkState(newCount >= 0, "Reference count became negative");

            if (newCount == 0) {
                // Grab the node before the group is removed so its own
                // references can still be released.
                PlanNode child = membership.get(group);
                deleteGroup(group);
                decrementReferenceCounts(child);
            }
        }
    }

    /**
     * Returns the distinct group ids referenced by the children of the given
     * node. Every child is cast to {@link GroupReference}, so the node must
     * already have been rewritten to reference groups.
     */
    private Set<Integer> getAllReferences(PlanNode node)
    {
        return node.getSources().stream()
                .map(GroupReference.class::cast)
                .map(GroupReference::getGroupId)
                .collect(Collectors.toSet());
    }

    /**
     * Removes the group's node and its reference count from the memo.
     */
    private void deleteGroup(int group)
    {
        membership.remove(group);
        referenceCounts.remove(group);
    }

    /**
     * Inserts every child of the given node into the memo and returns a copy
     * of the node whose children are group references to the resulting groups.
     */
    private PlanNode insertChildrenAndRewrite(PlanNode node)
    {
        return node.replaceChildren(
                node.getSources().stream()
                        .map(child -> new GroupReference(
                                idAllocator.getNextId(),
                                insertRecursive(child),
                                child.getOutputSymbols()))
                        .collect(Collectors.toList()));
    }

    /**
     * Inserts the given node and its subtree into the memo.
     *
     * @return the id of the group for the node: the referenced group if the
     * node is a group reference, otherwise a newly created group
     */
    private int insertRecursive(PlanNode node)
    {
        if (node instanceof GroupReference) {
            return ((GroupReference) node).getGroupId();
        }

        int group = nextGroupId();
        PlanNode rewritten = insertChildrenAndRewrite(node);

        membership.put(group, rewritten);
        // The new group starts unreferenced; whoever links to it bumps the count.
        referenceCounts.put(group, 0);
        incrementReferenceCounts(rewritten);

        return group;
    }

    /**
     * Returns the next unused group id and advances the counter.
     */
    private int nextGroupId()
    {
        return nextGroupId++;
    }

    /**
     * @return the number of groups currently stored in the memo
     */
    public int getGroupCount()
    {
        return membership.size();
    }
}