/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.plan;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * Base class for plan rewriters that walk a plan tree and produce, for every visited node,
 * a (possibly new) {@link PlanNode} together with an optional payload of type {@code P}.
 * <p>
 * Subclasses override the visit methods for the node types they want to transform. The
 * fallback implemented here, {@link #visitPlan}, rewrites the node's children and rebuilds
 * the node around them with a null payload. (Which visit methods delegate to
 * {@code visitPlan} is decided by {@code PlanVisitor}, which is not shown here.)
 *
 * @param <C> type of the user context handed down the tree; may be null
 * @param <P> type of the payload returned alongside each rewritten node; may be null
 */
public abstract class PlanRewriter<C, P>
        extends PlanVisitor<PlanRewriter.RewriteContext<C, P>, PlanRewriter.Result<P>>
{
    /**
     * Rewrites {@code node} with {@code rewriter}, starting from a null user context.
     *
     * @param rewriter the rewriter to apply; must not be null
     * @param node the root of the plan (sub)tree to rewrite
     * @return the rewritten node and the payload produced for it
     */
    public static <C, P> Result<P> rewriteWith(PlanRewriter<C, P> rewriter, PlanNode node)
    {
        return node.accept(rewriter, new RewriteContext<>(rewriter, null));
    }

    /**
     * Rewrites {@code node} with {@code rewriter}, starting from the given user context.
     *
     * @param rewriter the rewriter to apply; must not be null
     * @param node the root of the plan (sub)tree to rewrite
     * @param context the initial user context; may be null
     * @return the rewritten node and the payload produced for it
     */
    public static <C, P> Result<P> rewriteWith(PlanRewriter<C, P> rewriter, PlanNode node, C context)
    {
        return node.accept(rewriter, new RewriteContext<>(rewriter, context));
    }

    /**
     * Fallback visit: rewrites all children, passing along the current user context, and
     * returns a copy of {@code node} wired to the rewritten children with a null payload.
     */
    @Override
    protected Result<P> visitPlan(PlanNode node, RewriteContext<C, P> context)
    {
        return context.defaultRewrite(node, context.get());
    }

    /**
     * The outcome of rewriting one node: the resulting plan node (never null) and an
     * optional payload (may be null).
     */
    public static class Result<P>
    {
        private final PlanNode planNode;
        private final P payload;

        /**
         * @param planNode the rewritten node; must not be null
         * @param payload extra information produced by the rewrite; may be null
         * @throws NullPointerException if {@code planNode} is null
         */
        public Result(PlanNode planNode, @Nullable P payload)
        {
            this.planNode = requireNonNull(planNode, "planNode is null");
            this.payload = payload;
        }

        /** Returns the rewritten plan node; never null. */
        public PlanNode getPlanNode()
        {
            return planNode;
        }

        /** Returns the payload produced by the rewrite; may be null. */
        public P getPayload()
        {
            return payload;
        }
    }

    /**
     * Per-node context handed to visit methods. It carries the user context for the
     * current node and lets the visit method recursively rewrite children with the same
     * rewriter.
     */
    public static class RewriteContext<C, P>
    {
        private final C userContext;
        private final PlanRewriter<C, P> nodeRewriter;

        private RewriteContext(PlanRewriter<C, P> nodeRewriter, @Nullable C userContext)
        {
            this.nodeRewriter = requireNonNull(nodeRewriter, "nodeRewriter is null");
            this.userContext = userContext;
        }

        /** Returns the user context for the node being visited; may be null. */
        public C get()
        {
            return userContext;
        }

        /**
         * Invoke the rewrite logic recursively on children of the given node and swap it
         * out with an identical copy with the rewritten children. The children are visited
         * with a null user context, and the final payload will be null.
         */
        public Result<P> defaultRewrite(PlanNode node)
        {
            return defaultRewrite(node, null, payloads -> null);
        }

        /**
         * Invoke the rewrite logic recursively on children of the given node and swap it
         * out with an identical copy with the rewritten children. Each child is visited
         * with {@code context} as its user context. The final payload will be null.
         */
        public Result<P> defaultRewrite(PlanNode node, C context)
        {
            return defaultRewrite(node, context, payloads -> null);
        }

        /**
         * Invoke the rewrite logic recursively on children of the given node and swap it
         * out with an identical copy with the rewritten children. The children are visited
         * with a null user context. The payloadCombiner is used to produce the final
         * payload given the respective payloads of the children.
         */
        public Result<P> defaultRewrite(PlanNode node, Function<List<P>, P> payloadCombiner)
        {
            return defaultRewrite(node, null, payloadCombiner);
        }

        /**
         * Invoke the rewrite logic recursively on children of the given node and swap it
         * out with an identical copy with the rewritten children. Each child is visited
         * with {@code context} as its user context. The payloadCombiner is used to produce
         * the final payload given the respective payloads of the children.
         *
         * @param node the node whose children are rewritten
         * @param context the user context for every child; may be null
         * @param payloadCombiner receives the children's payloads (in source order, nulls
         *        included) and returns the payload for the rebuilt node; must not be null
         * @throws NullPointerException if {@code payloadCombiner} is null
         */
        public Result<P> defaultRewrite(PlanNode node, C context, Function<List<P>, P> payloadCombiner)
        {
            // Fail fast, before recursing into the whole subtree, rather than at apply().
            requireNonNull(payloadCombiner, "payloadCombiner is null");

            List<PlanNode> children = new ArrayList<>(node.getSources().size());
            List<P> payloads = new ArrayList<>(node.getSources().size());
            for (PlanNode source : node.getSources()) {
                Result<P> result = rewrite(source, context);
                children.add(result.getPlanNode());
                payloads.add(result.getPayload());
            }
            return new Result<>(ChildReplacer.replaceChildren(node, children), payloadCombiner.apply(payloads));
        }

        /**
         * This method is meant for invoking the rewrite logic on children while processing a node.
         *
         * @param node the node to rewrite
         * @param userContext the user context for {@code node}; may be null
         * @return the rewriter's result for {@code node}
         * @throws NullPointerException if the rewriter returns null for {@code node}
         */
        public Result<P> rewrite(PlanNode node, C userContext)
        {
            Result<P> result = node.accept(nodeRewriter, new RewriteContext<>(nodeRewriter, userContext));
            // Supplier overload: the message is only formatted on the failure path,
            // not on every visited node.
            return requireNonNull(result, () -> format("nodeRewriter returned null for %s", node.getClass().getName()));
        }

        /**
         * This method is meant for invoking the rewrite logic on children while processing a
         * node. The child is visited with a null user context.
         */
        public Result<P> rewrite(PlanNode node)
        {
            return rewrite(node, null);
        }
    }
}