/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.StatsRecorder;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.units.Duration;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.spi.StandardErrorCode.OPTIMIZER_TIMEOUT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
/**
 * A {@link PlanOptimizer} that rewrites a plan by repeatedly applying a set of {@link Rule}s
 * to the plan stored in a {@link Memo}.
 * <p>
 * For each group, every rule is tried against the group's current node. Whenever a rule
 * produces a replacement, the memo is updated and another full pass over the rules is made.
 * Once no rule fires, the child groups are explored recursively. If any child changed, the
 * group is explored again, because the changed children may enable additional matches.
 * <p>
 * When the new optimizer is disabled for the session and legacy optimizers were supplied,
 * the legacy optimizers are run in order instead of the rules.
 */
public class IterativeOptimizer
        implements PlanOptimizer
{
    private final List<PlanOptimizer> legacyRules;
    private final Set<Rule> rules;
    private final StatsRecorder stats;

    /**
     * Creates an optimizer that applies only the given rules, with no legacy fallback.
     *
     * @param stats recorder that receives per-rule timing, match and failure statistics
     * @param rules the rules to apply; copied into an immutable set
     * @throws NullPointerException if any argument is null
     */
    public IterativeOptimizer(StatsRecorder stats, Set<Rule> rules)
    {
        this(stats, ImmutableList.of(), rules);
    }

    /**
     * Creates an optimizer that applies {@code newRules}. It falls back to {@code legacyRules}
     * when the new optimizer is disabled for a session and {@code legacyRules} is non-empty.
     * All rules are registered with {@code stats}.
     *
     * @param stats recorder that receives per-rule timing, match and failure statistics
     * @param legacyRules optimizers run in sequence when the new optimizer is disabled; copied
     * @param newRules the rules to apply; copied into an immutable set
     * @throws NullPointerException if any argument is null
     */
    public IterativeOptimizer(StatsRecorder stats, List<PlanOptimizer> legacyRules, Set<Rule> newRules)
    {
        this.stats = requireNonNull(stats, "stats is null");
        this.legacyRules = ImmutableList.copyOf(requireNonNull(legacyRules, "legacyRules is null"));
        this.rules = ImmutableSet.copyOf(requireNonNull(newRules, "newRules is null"));
        stats.registerAll(rules);
    }

    /**
     * Optimizes {@code plan} with the rules, or with the legacy optimizers (see class docs).
     *
     * @return the optimized plan
     * @throws PrestoException with {@code OPTIMIZER_TIMEOUT} if the session's optimizer
     *         timeout elapses while rules are being applied
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
    {
        // only disable new rules if we have legacy rules to fall back to
        if (!SystemSessionProperties.isNewOptimizerEnabled(session) && !legacyRules.isEmpty()) {
            for (PlanOptimizer optimizer : legacyRules) {
                // NOTE(review): types are re-read from the symbol allocator on every iteration instead of
                // using the `types` argument, presumably so each legacy optimizer sees symbols allocated by
                // the optimizers that ran before it — confirm
                plan = optimizer.optimize(plan, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator);
            }
            return plan;
        }

        Memo memo = new Memo(idAllocator, plan);
        // Resolves GroupReference placeholders to the node currently stored in the memo;
        // any other node is returned unchanged.
        Lookup lookup = node -> {
            if (node instanceof GroupReference) {
                return memo.getNode(((GroupReference) node).getGroupId());
            }
            return node;
        };

        Duration timeout = SystemSessionProperties.getOptimizerTimeout(session);
        Context context = new Context(memo, lookup, idAllocator, symbolAllocator, System.nanoTime(), timeout.toMillis(), session);
        exploreGroup(memo.getRootGroup(), context);
        return memo.extract();
    }

    /**
     * Explores a group and, recursively, its children until neither yields further changes.
     *
     * @return true if this group or any descendant group changed
     */
    private boolean exploreGroup(int group, Context context)
    {
        // tracks whether this group or any children groups change as
        // this method executes
        boolean progress = exploreNode(group, context);

        while (exploreChildren(group, context)) {
            progress = true;

            // if children changed, try current group again
            // in case we can match additional rules
            if (!exploreNode(group, context)) {
                // no additional matches, so bail out
                break;
            }
        }

        return progress;
    }

    /**
     * Applies every rule to the group's node, repeating full passes until a pass makes no change.
     * The time limit is checked at the start of each pass.
     *
     * @return true if at least one rule replaced the group's node
     * @throws PrestoException with {@code OPTIMIZER_TIMEOUT} if the time limit is exhausted
     */
    private boolean exploreNode(int group, Context context)
    {
        PlanNode node = context.getMemo().getNode(group);

        boolean done = false;
        boolean progress = false;

        while (!done) {
            if (isTimeLimitExhausted(context)) {
                throw new PrestoException(OPTIMIZER_TIMEOUT, format("The optimizer exhausted the time limit of %d ms", context.getTimeoutInMilliseconds()));
            }

            done = true;
            for (Rule rule : rules) {
                Optional<PlanNode> transformed = applyRule(rule, node, context);

                if (transformed.isPresent()) {
                    // the node returned by the memo is what subsequent rules in this pass see
                    node = context.getMemo().replace(group, transformed.get(), rule.getClass().getName());

                    done = false;
                    progress = true;
                }
            }
        }

        return progress;
    }

    /**
     * Applies a single rule to {@code node} and records its duration and whether it matched.
     * A failure is recorded before the exception is rethrown.
     *
     * @return the rule's replacement for {@code node}, or empty if the rule did not apply
     */
    private Optional<PlanNode> applyRule(Rule rule, PlanNode node, Context context)
    {
        Optional<PlanNode> transformed;
        long duration;
        try {
            long start = System.nanoTime();
            transformed = rule.apply(node, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator(), context.getSession());
            duration = System.nanoTime() - start;
        }
        catch (RuntimeException e) {
            stats.recordFailure(rule);
            throw e;
        }
        stats.record(rule, duration, transformed.isPresent());
        return transformed;
    }

    /**
     * Returns true once the whole milliseconds elapsed since the context's start time reach the
     * configured timeout. Elapsed time is truncated to milliseconds before the comparison.
     */
    private boolean isTimeLimitExhausted(Context context)
    {
        return ((System.nanoTime() - context.getStartTimeInNanos()) / 1_000_000) >= context.getTimeoutInMilliseconds();
    }

    /**
     * Explores every child group of the given group. Every child is visited even after one
     * reports progress.
     *
     * @return true if any child group changed
     * @throws IllegalStateException if a source of the memo node is not a {@link GroupReference}
     */
    private boolean exploreChildren(int group, Context context)
    {
        boolean progress = false;

        PlanNode expression = context.getMemo().getNode(group);
        for (PlanNode child : expression.getSources()) {
            // template form: the failure message is only formatted when the check fails,
            // avoiding a string allocation per child on this hot path
            checkState(child instanceof GroupReference, "Expected child to be a group reference. Found: %s", child.getClass().getName());

            if (exploreGroup(((GroupReference) child).getGroupId(), context)) {
                progress = true;
            }
        }

        return progress;
    }

    /**
     * Per-invocation state shared by the exploration methods: the memo and its lookup, the
     * allocators, the session, and the timing data used to enforce the optimizer timeout.
     */
    private static class Context
    {
        private final Memo memo;
        private final Lookup lookup;
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final long startTimeInNanos;
        private final long timeoutInMilliseconds;
        private final Session session;

        /**
         * @param startTimeInNanos a {@link System#nanoTime()} reading taken when optimization started
         * @param timeoutInMilliseconds time budget for rule application; must be non-negative
         * @throws IllegalArgumentException if {@code timeoutInMilliseconds} is negative
         */
        public Context(
                Memo memo,
                Lookup lookup,
                PlanNodeIdAllocator idAllocator,
                SymbolAllocator symbolAllocator,
                long startTimeInNanos,
                long timeoutInMilliseconds,
                Session session)
        {
            checkArgument(timeoutInMilliseconds >= 0, "Timeout has to be a non-negative number [milliseconds]");

            this.memo = memo;
            this.lookup = lookup;
            this.idAllocator = idAllocator;
            this.symbolAllocator = symbolAllocator;
            this.startTimeInNanos = startTimeInNanos;
            this.timeoutInMilliseconds = timeoutInMilliseconds;
            this.session = session;
        }

        public Memo getMemo()
        {
            return memo;
        }

        public Lookup getLookup()
        {
            return lookup;
        }

        public PlanNodeIdAllocator getIdAllocator()
        {
            return idAllocator;
        }

        public SymbolAllocator getSymbolAllocator()
        {
            return symbolAllocator;
        }

        public long getStartTimeInNanos()
        {
            return startTimeInNanos;
        }

        public long getTimeoutInMilliseconds()
        {
            return timeoutInMilliseconds;
        }

        public Session getSession()
        {
            return session;
        }
    }
}