/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Optional; import java.util.function.Predicate; import static com.facebook.presto.sql.planner.optimizations.Predicates.alwaysTrue; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class PlanNodeSearcher { public static PlanNodeSearcher searchFrom(PlanNode node) { return new PlanNodeSearcher(node); } private final PlanNode node; private Predicate<PlanNode> where = alwaysTrue(); private Predicate<PlanNode> skipOnly = alwaysTrue(); public PlanNodeSearcher(PlanNode node) { this.node = requireNonNull(node, "node is null"); } public PlanNodeSearcher where(Predicate<PlanNode> where) { this.where = requireNonNull(where, "where is null"); return this; } public PlanNodeSearcher skipOnlyWhen(Predicate<PlanNode> skipOnly) { this.skipOnly = requireNonNull(skipOnly, "skipOnly is null"); return this; } public <T extends PlanNode> Optional<T> findFirst() { return findFirstRecursive(node); } private <T extends PlanNode> Optional<T> findFirstRecursive(PlanNode node) { if (where.test(node)) { return Optional.of((T) node); } if (skipOnly.test(node)) { for (PlanNode source : node.getSources()) { Optional<T> found = findFirstRecursive(source); if (found.isPresent()) { return found; } } } return Optional.empty(); } public <T extends PlanNode> List<T> findAll() { ImmutableList.Builder<T> nodes = ImmutableList.builder(); findAllRecursive(node, nodes); return nodes.build(); } private <T extends PlanNode> void findAllRecursive(PlanNode node, ImmutableList.Builder<T> nodes) { if (where.test(node)) { nodes.add((T) node); } if (skipOnly.test(node)) { for (PlanNode source : node.getSources()) { findAllRecursive(source, nodes); } } } public PlanNode removeAll() { return removeAllRecursive(node); } private PlanNode removeAllRecursive(PlanNode node) { if (where.test(node)) { checkArgument( node.getSources().size() == 1, "Unable to remove plan node as it contains 0 or more than 1 children"); return node.getSources().get(0); } if (skipOnly.test(node)) { List<PlanNode> sources = node.getSources().stream() .map(source -> removeAllRecursive(source)) .collect(toImmutableList()); return replaceChildren(node, sources); } return node; } public PlanNode removeFirst() { return removeFirstRecursive(node); } private PlanNode removeFirstRecursive(PlanNode node) { if (where.test(node)) { checkArgument( node.getSources().size() == 1, "Unable to remove plan node as it contains 0 or more than 1 children"); return node.getSources().get(0); } if (skipOnly.test(node)) { List<PlanNode> sources = node.getSources(); if (sources.isEmpty()) { return node; } else if (sources.size() == 1) { return replaceChildren(node, ImmutableList.of(removeFirstRecursive(sources.get(0)))); } else { throw new IllegalArgumentException("Unable to remove first node when a node has multiple children, use removeAll instead"); } } return node; } public PlanNode replaceAll(PlanNode newPlanNode) { return replaceAllRecursive(node, newPlanNode); } private PlanNode replaceAllRecursive(PlanNode node, PlanNode nodeToReplace) { if (where.test(node)) { return nodeToReplace; } if (skipOnly.test(node)) { List<PlanNode> sources = node.getSources().stream() .map(source -> replaceAllRecursive(source, nodeToReplace)) .collect(toImmutableList()); return replaceChildren(node, sources); } return node; } public PlanNode replaceFirst(PlanNode newPlanNode) { return replaceFirstRecursive(node, newPlanNode); } private PlanNode replaceFirstRecursive(PlanNode node, PlanNode nodeToReplace) { if (where.test(node)) { return nodeToReplace; } List<PlanNode> sources = node.getSources(); if (sources.isEmpty()) { return node; } else if (sources.size() == 1) { return replaceChildren(node, ImmutableList.of(replaceFirstRecursive(node, sources.get(0)))); } else { throw new IllegalArgumentException("Unable to replace first node when a node has multiple children, use replaceAll instead"); } } public boolean matches() { return findFirst().isPresent(); } public int count() { return findAll().size(); } }