/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql; import com.facebook.presto.sql.planner.DependencyExtractor; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.LambdaExpression; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import java.util.ArrayDeque; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Queue; import java.util.Set; import java.util.function.Function; import java.util.function.Predicate; import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.facebook.presto.sql.tree.ComparisonExpressionType.IS_DISTINCT_FROM; import static com.google.common.collect.ImmutableList.toImmutableList; import static 
java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

/**
 * Static helpers for building, decomposing and normalizing boolean {@link Expression} trees.
 */
public final class ExpressionUtils
{
    private ExpressionUtils() {}

    /**
     * Flattens nested {@code AND} expressions into their conjuncts, in left-to-right order.
     * An expression that is not an {@code AND} is returned as a single-element list.
     */
    public static List<Expression> extractConjuncts(Expression expression)
    {
        return extractPredicates(LogicalBinaryExpression.Type.AND, expression);
    }

    /**
     * Flattens nested {@code OR} expressions into their disjuncts, in left-to-right order.
     * An expression that is not an {@code OR} is returned as a single-element list.
     */
    public static List<Expression> extractDisjuncts(Expression expression)
    {
        return extractPredicates(LogicalBinaryExpression.Type.OR, expression);
    }

    /**
     * Flattens {@code expression} using its own logical operator as the separator.
     */
    public static List<Expression> extractPredicates(LogicalBinaryExpression expression)
    {
        return extractPredicates(expression.getType(), expression);
    }

    /**
     * Flattens every maximal sub-tree of {@link LogicalBinaryExpression}s whose operator is {@code type}
     * into the list of its operands, preserving left-to-right order. Operands that use a different
     * operator, or that are not logical binary expressions at all, are kept intact.
     * <p>
     * Uses an explicit work stack, so the cost is linear in the size of the flattened tree and deeply
     * nested (e.g. left-leaning) chains cannot overflow the call stack.
     *
     * @return an immutable list of the extracted predicates; never empty
     * @throws NullPointerException if {@code expression} or any traversed operand is null
     */
    public static List<Expression> extractPredicates(LogicalBinaryExpression.Type type, Expression expression)
    {
        ImmutableList.Builder<Expression> predicates = ImmutableList.builder();
        ArrayDeque<Expression> pending = new ArrayDeque<>();
        pending.push(expression);
        while (!pending.isEmpty()) {
            Expression current = pending.pop();
            if (current instanceof LogicalBinaryExpression && ((LogicalBinaryExpression) current).getType() == type) {
                LogicalBinaryExpression logicalBinaryExpression = (LogicalBinaryExpression) current;
                // Push the right operand first so the left operand is popped (and emitted) first.
                pending.push(logicalBinaryExpression.getRight());
                pending.push(logicalBinaryExpression.getLeft());
            }
            else {
                predicates.add(current);
            }
        }
        return predicates.build();
    }

    /**
     * Combines the expressions with {@code AND} into a balanced tree. See {@link #binaryExpression}.
     */
    public static Expression and(Expression... expressions)
    {
        return and(Arrays.asList(expressions));
    }

    /**
     * Combines the expressions with {@code AND} into a balanced tree. See {@link #binaryExpression}.
     */
    public static Expression and(Collection<Expression> expressions)
    {
        return binaryExpression(LogicalBinaryExpression.Type.AND, expressions);
    }

    /**
     * Combines the expressions with {@code OR} into a balanced tree. See {@link #binaryExpression}.
     */
    public static Expression or(Expression... expressions)
    {
        return or(Arrays.asList(expressions));
    }

    /**
     * Combines the expressions with {@code OR} into a balanced tree. See {@link #binaryExpression}.
     */
    public static Expression or(Collection<Expression> expressions)
    {
        return binaryExpression(LogicalBinaryExpression.Type.OR, expressions);
    }

    /**
     * Joins {@code expressions} with the logical operator {@code type}. No simplification or
     * de-duplication is performed. A single-element collection yields that element unchanged.
     *
     * @throws NullPointerException if {@code type} or {@code expressions} is null
     * @throws IllegalArgumentException if {@code expressions} is empty
     */
    public static Expression binaryExpression(LogicalBinaryExpression.Type type, Collection<Expression> expressions)
    {
        requireNonNull(type, "type is null");
        requireNonNull(expressions, "expressions is null");
        Preconditions.checkArgument(!expressions.isEmpty(), "expressions is empty");

        // Build a balanced tree for efficient recursive processing that
        // preserves the evaluation order of the input expressions.
        //
        // The tree is built bottom up by combining pairs of elements into
        // binary expressions of the requested type (AND or OR).
        //
        // Example:
        //
        // Initial state:
        //  a b c d e
        //
        // First iteration:
        //
        //  /\    /\   e
        // a  b  c  d
        //
        // Second iteration:
        //
        //    / \     e
        //  /\   /\
        // a  b c  d
        //
        //
        // Last iteration:
        //
        //      / \
        //    / \  e
        //  /\   /\
        // a  b c  d

        Queue<Expression> queue = new ArrayDeque<>(expressions);
        while (queue.size() > 1) {
            Queue<Expression> buffer = new ArrayDeque<>();

            // combine pairs of elements
            while (queue.size() >= 2) {
                buffer.add(new LogicalBinaryExpression(type, queue.remove(), queue.remove()));
            }

            // if there's an odd number of elements, just append the last one
            if (!queue.isEmpty()) {
                buffer.add(queue.remove());
            }

            // continue processing the pairs that were just built
            queue = buffer;
        }

        return queue.remove();
    }

    /**
     * Simplifying combination: delegates to {@link #combineConjuncts} for {@code AND} and to
     * {@link #combineDisjuncts} for any other type.
     */
    public static Expression combinePredicates(LogicalBinaryExpression.Type type, Expression... expressions)
    {
        return combinePredicates(type, Arrays.asList(expressions));
    }

    /**
     * Simplifying combination: delegates to {@link #combineConjuncts} for {@code AND} and to
     * {@link #combineDisjuncts} for any other type.
     */
    public static Expression combinePredicates(LogicalBinaryExpression.Type type, Collection<Expression> expressions)
    {
        if (type == LogicalBinaryExpression.Type.AND) {
            return combineConjuncts(expressions);
        }

        return combineDisjuncts(expressions);
    }

    /**
     * Simplifying {@code AND}: see {@link #combineConjunctsWithDefault}, with {@code TRUE} as the default.
     */
    public static Expression combineConjuncts(Expression... expressions)
    {
        return combineConjuncts(Arrays.asList(expressions));
    }

    /**
     * Simplifying {@code AND}: see {@link #combineConjunctsWithDefault}, with {@code TRUE} as the default.
     */
    public static Expression combineConjuncts(Collection<Expression> expressions)
    {
        return combineConjunctsWithDefault(expressions, TRUE_LITERAL);
    }

    /**
     * ANDs the expressions together after simplification. The steps are:
     * <ol>
     * <li>flatten nested {@code AND}s;</li>
     * <li>drop {@code TRUE} literals;</li>
     * <li>remove duplicate deterministic conjuncts;</li>
     * <li>short-circuit to {@code FALSE} if any conjunct is the {@code FALSE} literal.</li>
     * </ol>
     *
     * @param emptyDefault returned when no conjuncts remain after simplification
     */
    public static Expression combineConjunctsWithDefault(Collection<Expression> expressions, Expression emptyDefault)
    {
        requireNonNull(expressions, "expressions is null");

        List<Expression> conjuncts = expressions.stream()
                .flatMap(e -> ExpressionUtils.extractConjuncts(e).stream())
                .filter(e -> !e.equals(TRUE_LITERAL))
                .collect(toList());

        conjuncts = removeDuplicates(conjuncts);

        if (conjuncts.contains(FALSE_LITERAL)) {
            return FALSE_LITERAL;
        }

        return conjuncts.isEmpty() ? emptyDefault : and(conjuncts);
    }

    /**
     * Simplifying {@code OR}: see {@link #combineDisjunctsWithDefault}, with {@code FALSE} as the default.
     */
    public static Expression combineDisjuncts(Expression... expressions)
    {
        return combineDisjuncts(Arrays.asList(expressions));
    }

    /**
     * Simplifying {@code OR}: see {@link #combineDisjunctsWithDefault}, with {@code FALSE} as the default.
     */
    public static Expression combineDisjuncts(Collection<Expression> expressions)
    {
        return combineDisjunctsWithDefault(expressions, FALSE_LITERAL);
    }

    /**
     * ORs the expressions together after simplification. The steps are:
     * <ol>
     * <li>flatten nested {@code OR}s;</li>
     * <li>drop {@code FALSE} literals;</li>
     * <li>remove duplicate deterministic disjuncts;</li>
     * <li>short-circuit to {@code TRUE} if any disjunct is the {@code TRUE} literal.</li>
     * </ol>
     *
     * @param emptyDefault returned when no disjuncts remain after simplification
     */
    public static Expression combineDisjunctsWithDefault(Collection<Expression> expressions, Expression emptyDefault)
    {
        requireNonNull(expressions, "expressions is null");

        List<Expression> disjuncts = expressions.stream()
                .flatMap(e -> ExpressionUtils.extractDisjuncts(e).stream())
                .filter(e -> !e.equals(FALSE_LITERAL))
                .collect(toList());

        disjuncts = removeDuplicates(disjuncts);

        if (disjuncts.contains(TRUE_LITERAL)) {
            return TRUE_LITERAL;
        }

        return disjuncts.isEmpty() ? emptyDefault : or(disjuncts);
    }

    /**
     * Keeps only the deterministic top-level conjuncts of {@code expression} and recombines them
     * (yielding {@code TRUE} if none remain).
     */
    public static Expression stripNonDeterministicConjuncts(Expression expression)
    {
        List<Expression> conjuncts = extractConjuncts(expression).stream()
                .filter(DeterminismEvaluator::isDeterministic)
                .collect(toList());

        return combineConjuncts(conjuncts);
    }

    /**
     * Keeps only the non-deterministic top-level conjuncts of {@code expression} and recombines them
     * (yielding {@code TRUE} if none remain).
     */
    public static Expression stripDeterministicConjuncts(Expression expression)
    {
        return combineConjuncts(extractConjuncts(expression)
                .stream()
                .filter((conjunct) -> !DeterminismEvaluator.isDeterministic(conjunct))
                .collect(toImmutableList()));
    }

    /**
     * Returns a function that maps an expression {@code e} to
     * {@code e OR (s1 IS NULL AND s2 IS NULL ...) OR ...}. Each scope in {@code nullSymbolScopes}
     * contributes one extra disjunct: the {@code IS NULL} checks, ANDed together, for the symbols
     * referenced by {@code e} that the scope accepts. A scope that accepts none of them contributes
     * nothing.
     */
    @SafeVarargs
    public static Function<Expression, Expression> expressionOrNullSymbols(final Predicate<Symbol>... nullSymbolScopes)
    {
        // Defensive copy: the returned function outlives this call, so later writes to the caller's
        // array must not change its behavior (this is also what makes @SafeVarargs sound here).
        List<Predicate<Symbol>> scopes = ImmutableList.copyOf(nullSymbolScopes);

        return expression -> {
            // The referenced symbols do not depend on the scope, so extract them once per expression.
            Collection<Symbol> referencedSymbols = DependencyExtractor.extractUnique(expression);

            ImmutableList.Builder<Expression> resultDisjunct = ImmutableList.builder();
            resultDisjunct.add(expression);

            for (Predicate<Symbol> nullSymbolScope : scopes) {
                List<Symbol> symbols = referencedSymbols.stream()
                        .filter(nullSymbolScope)
                        .collect(toImmutableList());

                if (Iterables.isEmpty(symbols)) {
                    continue;
                }

                ImmutableList.Builder<Expression> nullConjuncts = ImmutableList.builder();
                for (Symbol symbol : symbols) {
                    nullConjuncts.add(new IsNullPredicate(symbol.toSymbolReference()));
                }

                resultDisjunct.add(and(nullConjuncts.build()));
            }

            return or(resultDisjunct.build());
        };
    }

    /**
     * Removes duplicate deterministic expressions. Preserves the relative order
     * of the expressions in the list.
     */
    private static List<Expression> removeDuplicates(List<Expression> expressions)
    {
        Set<Expression> seen = new HashSet<>();

        ImmutableList.Builder<Expression> result = ImmutableList.builder();
        for (Expression expression : expressions) {
            // Non-deterministic expressions are always kept: two occurrences may evaluate differently,
            // so collapsing them could change semantics. They are also never recorded in `seen`.
            if (!DeterminismEvaluator.isDeterministic(expression) || seen.add(expression)) {
                result.add(expression);
            }
        }

        return result.build();
    }

    /**
     * Pushes a top-level {@code NOT} inward where possible. Two rewrites are applied:
     * <ul>
     * <li>{@code NOT (a op b)} becomes {@code a negate(op) b} for comparisons other than
     * {@code IS DISTINCT FROM};</li>
     * <li>{@code NOT NOT x} becomes {@code normalize(x)}.</li>
     * </ul>
     * Any other expression is returned unchanged.
     */
    public static Expression normalize(Expression expression)
    {
        if (expression instanceof NotExpression) {
            NotExpression not = (NotExpression) expression;
            // NOTE(review): IS DISTINCT FROM is excluded, presumably because it has no negated
            // comparison type — confirm against ComparisonExpressionType.negate().
            if (not.getValue() instanceof ComparisonExpression && ((ComparisonExpression) not.getValue()).getType() != IS_DISTINCT_FROM) {
                ComparisonExpression comparison = (ComparisonExpression) not.getValue();
                return new ComparisonExpression(comparison.getType().negate(), comparison.getLeft(), comparison.getRight());
            }
            if (not.getValue() instanceof NotExpression) {
                return normalize(((NotExpression) not.getValue()).getValue());
            }
        }
        return expression;
    }

    /**
     * Replaces every {@link Identifier} in {@code expression} with a {@link SymbolReference} of the same
     * name. Lambda expressions are rebuilt with their original argument declarations and a rewritten body.
     */
    public static Expression rewriteIdentifiersToSymbolReferences(Expression expression)
    {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
        {
            @Override
            public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
            {
                return new SymbolReference(node.getName());
            }

            @Override
            public Expression rewriteLambdaExpression(LambdaExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
            {
                return new LambdaExpression(node.getArguments(), treeRewriter.rewrite(node.getBody(), context));
            }
        }, expression);
    }
}