/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.sql.analyzer; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.planner.ParameterRewriter; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression; import com.facebook.presto.sql.tree.ArrayConstructor; import com.facebook.presto.sql.tree.AstVisitor; import com.facebook.presto.sql.tree.AtTimeZone; import com.facebook.presto.sql.tree.BetweenPredicate; import com.facebook.presto.sql.tree.BindExpression; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.CurrentTime; import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.Extract; import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.IsNotNullPredicate; import com.facebook.presto.sql.tree.IsNullPredicate; import 
com.facebook.presto.sql.tree.LambdaExpression; import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.Row; import com.facebook.presto.sql.tree.SearchedCaseExpression; import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; import com.google.common.collect.ImmutableList; import javax.annotation.Nullable; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; import static com.facebook.presto.sql.analyzer.LambdaReferenceExtractor.hasReferencesToLambdaArgument; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.getReferencesToScope; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.hasReferencesToScope; import static com.facebook.presto.sql.analyzer.ScopeReferenceExtractor.isFieldFromScope; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MUST_BE_AGGREGATE_OR_GROUP_BY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MUST_BE_AGGREGATION_FUNCTION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NESTED_AGGREGATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NESTED_WINDOW; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static 
com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;

/**
 * Checks whether an expression is constant with respect to the group.
 * <p>
 * An expression is accepted when every column it references outside of an aggregate call is a
 * GROUP BY key, or (when analyzing ORDER BY) resolves to the ORDER BY scope. Violations are
 * reported by throwing {@link SemanticException}.
 */
class AggregationAnalyzer
{
    // fields and expressions in the group by clause
    private final Set<FieldId> groupingFields;
    // GROUP BY expressions with parameter placeholders substituted, used for structural matching in Visitor.process
    private final List<Expression> expressions;
    private final Map<Expression, FieldId> columnReferences;

    private final Metadata metadata;
    private final Analysis analysis;

    private final Scope sourceScope;
    // present only when created through verifyOrderByAggregations
    private final Optional<Scope> orderByScope;

    /**
     * Verifies that {@code expression} is constant with respect to the group, with column
     * references resolved against {@code sourceScope} only.
     *
     * @param groupByExpressions the GROUP BY expressions
     * @param sourceScope scope every grouping field must originate from
     * @param expression the expression to verify
     * @param metadata used to recognize aggregation functions
     * @param analysis supplies column references, parameters and lambda-argument references
     * @throws SemanticException if the expression is neither an aggregate nor derived from GROUP BY keys
     */
    public static void verifySourceAggregations(
            List<Expression> groupByExpressions,
            Scope sourceScope,
            Expression expression,
            Metadata metadata,
            Analysis analysis)
    {
        AggregationAnalyzer analyzer = new AggregationAnalyzer(groupByExpressions, sourceScope, Optional.empty(), metadata, analysis);
        analyzer.analyze(expression);
    }

    /**
     * Like {@link #verifySourceAggregations}, but references resolving to {@code orderByScope}
     * (output projection attributes) are also accepted outside aggregate calls. Inside the
     * arguments of an aggregate call such references are rejected with
     * {@code REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION}.
     *
     * @param orderByScope scope of the ORDER BY clause
     * @throws SemanticException if the expression is not valid in this context
     */
    public static void verifyOrderByAggregations(
            List<Expression> groupByExpressions,
            Scope sourceScope,
            Scope orderByScope,
            Expression expression,
            Metadata metadata,
            Analysis analysis)
    {
        AggregationAnalyzer analyzer = new AggregationAnalyzer(groupByExpressions, sourceScope, Optional.of(orderByScope), metadata, analysis);
        analyzer.analyze(expression);
    }

    /**
     * Captures the analysis state.
     * <p>
     * GROUP BY expressions are rewritten with {@link ParameterRewriter} so that parameter
     * placeholders are replaced by the bound values from {@code analysis.getParameters()}. The
     * grouping fields are the {@link FieldId}s of those GROUP BY expressions that are column
     * references; each must originate from {@code sourceScope}.
     *
     * @throws IllegalStateException if a grouping field does not originate from {@code sourceScope}
     */
    private AggregationAnalyzer(List<Expression> groupByExpressions, Scope sourceScope, Optional<Scope> orderByScope, Metadata metadata, Analysis analysis)
    {
        requireNonNull(groupByExpressions, "groupByExpressions is null");
        requireNonNull(sourceScope, "sourceScope is null");
        requireNonNull(orderByScope, "orderByScope is null");
        requireNonNull(metadata, "metadata is null");
        requireNonNull(analysis, "analysis is null");

        this.sourceScope = sourceScope;
        this.orderByScope = orderByScope;
        this.metadata = metadata;
        this.analysis = analysis;
        this.expressions = groupByExpressions.stream()
                .map(e -> ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters()), e))
                .collect(toImmutableList());

        this.columnReferences = analysis.getColumnReferenceFields();

        this.groupingFields = groupByExpressions.stream()
                .filter(columnReferences::containsKey)
                .map(columnReferences::get)
                .collect(toImmutableSet());

        this.groupingFields.forEach(fieldId -> checkState(
                isFieldFromScope(fieldId, sourceScope),
                "Grouping field %s should originate from %s",
                fieldId,
                sourceScope.getRelationType()));
    }

    /**
     * Runs the visitor over {@code expression} and throws if it is not constant with respect to the group.
     */
    private void analyze(Expression expression)
    {
        Visitor visitor = new Visitor();
        if (!visitor.process(expression, null)) {
            throw new SemanticException(MUST_BE_AGGREGATE_OR_GROUP_BY, expression, "'%s' must be an aggregate expression or appear in GROUP BY clause", expression);
        }
    }

    /**
     * visitor returns true if all expressions are constant with respect to the group.
     * <p>
     * Several visits throw {@link SemanticException} directly instead of returning false, so the
     * error points at the offending node. Expression types without a dedicated visit fail with
     * {@link UnsupportedOperationException}.
     */
    private class Visitor
            extends AstVisitor<Boolean, Void>
    {
        @Override
        protected Boolean visitExpression(Expression node, Void context)
        {
            throw new UnsupportedOperationException("aggregation analysis not yet implemented for: " + node.getClass().getName());
        }

        @Override
        protected Boolean visitAtTimeZone(AtTimeZone node, Void context)
        {
            return process(node.getValue(), context);
        }

        @Override
        protected Boolean visitSubqueryExpression(SubqueryExpression node, Void context)
        {
            /*
             * Column reference can resolve to (a) some subquery's scope, (b) a projection (ORDER BY scope),
             * (c) source scope or (d) outer query scope (effectively a constant).
             * From AggregationAnalyzer's perspective, only case (c) needs verification.
             */
            getReferencesToScope(node, analysis, sourceScope)
                    .filter(expression -> !isGroupingKey(expression))
                    .findFirst()
                    .ifPresent(expression -> {
                        throw new SemanticException(MUST_BE_AGGREGATE_OR_GROUP_BY, expression,
                                "Subquery uses '%s' which must appear in GROUP BY clause", expression);
                    });

            return true;
        }

        @Override
        protected Boolean visitExists(ExistsPredicate node, Void context)
        {
            checkState(node.getSubquery() instanceof SubqueryExpression);
            return process(node.getSubquery(), context);
        }

        @Override
        protected Boolean visitSubscriptExpression(SubscriptExpression node, Void context)
        {
            return process(node.getBase(), context) &&
                    process(node.getIndex(), context);
        }

        @Override
        protected Boolean visitArrayConstructor(ArrayConstructor node, Void context)
        {
            return node.getValues().stream().allMatch(expression -> process(expression, context));
        }

        @Override
        protected Boolean visitCast(Cast node, Void context)
        {
            return process(node.getExpression(), context);
        }

        @Override
        protected Boolean visitCoalesceExpression(CoalesceExpression node, Void context)
        {
            return node.getOperands().stream().allMatch(expression -> process(expression, context));
        }

        @Override
        protected Boolean visitNullIfExpression(NullIfExpression node, Void context)
        {
            return process(node.getFirst(), context) && process(node.getSecond(), context);
        }

        @Override
        protected Boolean visitExtract(Extract node, Void context)
        {
            return process(node.getExpression(), context);
        }

        @Override
        protected Boolean visitBetweenPredicate(BetweenPredicate node, Void context)
        {
            return process(node.getMin(), context) &&
                    process(node.getValue(), context) &&
                    process(node.getMax(), context);
        }

        @Override
        protected Boolean visitCurrentTime(CurrentTime node, Void context)
        {
            return true;
        }

        @Override
        protected Boolean visitArithmeticBinary(ArithmeticBinaryExpression node, Void context)
        {
            return process(node.getLeft(), context) && process(node.getRight(), context);
        }

        @Override
        protected Boolean visitComparisonExpression(ComparisonExpression node, Void context)
        {
            return process(node.getLeft(), context) && process(node.getRight(), context);
        }

        @Override
        protected Boolean visitLiteral(Literal node, Void context)
        {
            return true;
        }

        @Override
        protected Boolean visitIsNotNullPredicate(IsNotNullPredicate node, Void context)
        {
            return process(node.getValue(), context);
        }

        @Override
        protected Boolean visitIsNullPredicate(IsNullPredicate node, Void context)
        {
            return process(node.getValue(), context);
        }

        @Override
        protected Boolean visitLikePredicate(LikePredicate node, Void context)
        {
            return process(node.getValue(), context) && process(node.getPattern(), context);
        }

        @Override
        protected Boolean visitInListExpression(InListExpression node, Void context)
        {
            return node.getValues().stream().allMatch(expression -> process(expression, context));
        }

        @Override
        protected Boolean visitInPredicate(InPredicate node, Void context)
        {
            return process(node.getValue(), context) && process(node.getValueList(), context);
        }

        /**
         * A non-windowed aggregate call is constant within the group by definition. Its arguments
         * are only checked for nested aggregates/window functions, FILTER combined with DISTINCT,
         * and (in ORDER BY mode) references to output columns. Any other call, including a
         * windowed aggregate, is valid when its window and all its arguments are.
         */
        @Override
        protected Boolean visitFunctionCall(FunctionCall node, Void context)
        {
            if (metadata.isAggregationFunction(node.getName())) {
                if (!node.getWindow().isPresent()) {
                    AggregateExtractor aggregateExtractor = new AggregateExtractor(metadata.getFunctionRegistry());
                    WindowFunctionExtractor windowExtractor = new WindowFunctionExtractor();

                    for (Expression argument : node.getArguments()) {
                        aggregateExtractor.process(argument, null);
                        windowExtractor.process(argument, null);
                    }

                    if (!aggregateExtractor.getAggregates().isEmpty()) {
                        throw new SemanticException(NESTED_AGGREGATION,
                                node,
                                "Cannot nest aggregations inside aggregation '%s': %s",
                                node.getName(),
                                aggregateExtractor.getAggregates());
                    }

                    if (!windowExtractor.getWindowFunctions().isEmpty()) {
                        throw new SemanticException(NESTED_WINDOW,
                                node,
                                "Cannot nest window functions inside aggregation '%s': %s",
                                node.getName(),
                                windowExtractor.getWindowFunctions());
                    }

                    if (node.getFilter().isPresent() && node.isDistinct()) {
                        throw new SemanticException(NOT_SUPPORTED,
                                node,
                                "Filtered aggregations not supported with DISTINCT: '%s'",
                                node);
                    }

                    // ensure that no output fields are referenced from ORDER BY clause
                    if (orderByScope.isPresent()) {
                        node.getArguments().forEach(AggregationAnalyzer.this::verifyNoOrderByReferencesToOutputColumns);
                    }

                    return true;
                }
            }
            else if (node.getFilter().isPresent()) {
                throw new SemanticException(MUST_BE_AGGREGATION_FUNCTION,
                        node,
                        "Filter is only valid for aggregation functions");
            }

            if (node.getWindow().isPresent() && !process(node.getWindow().get(), context)) {
                return false;
            }

            return node.getArguments().stream().allMatch(expression -> process(expression, context));
        }

        @Override
        protected Boolean visitLambdaExpression(LambdaExpression node, Void context)
        {
            return process(node.getBody(), context);
        }

        @Override
        protected Boolean visitBindExpression(BindExpression node, Void context)
        {
            return process(node.getValue(), context) && process(node.getFunction(), context);
        }

        /**
         * Every PARTITION BY and ORDER BY expression of the window, and the frame bounds, must be
         * constant within the group; otherwise a {@link SemanticException} is thrown.
         */
        @Override
        public Boolean visitWindow(Window node, Void context)
        {
            for (Expression expression : node.getPartitionBy()) {
                if (!process(expression, context)) {
                    throw new SemanticException(MUST_BE_AGGREGATE_OR_GROUP_BY,
                            expression,
                            "PARTITION BY expression '%s' must be an aggregate expression or appear in GROUP BY clause",
                            expression);
                }
            }

            for (SortItem sortItem : getSortItemsFromOrderBy(node.getOrderBy())) {
                Expression expression = sortItem.getSortKey();
                if (!process(expression, context)) {
                    throw new SemanticException(MUST_BE_AGGREGATE_OR_GROUP_BY,
                            expression,
                            "ORDER BY expression '%s' must be an aggregate expression or appear in GROUP BY clause",
                            expression);
                }
            }

            // visitWindowFrame throws instead of returning false; its result is propagated for consistency
            return !node.getFrame().isPresent() || process(node.getFrame().get(), context);
        }

        /**
         * Checks the optional start and end bound values of a window frame; throws if either is
         * not constant within the group, otherwise returns true.
         */
        @Override
        public Boolean visitWindowFrame(WindowFrame node, Void context)
        {
            Optional<Expression> start = node.getStart().getValue();
            if (start.isPresent()) {
                if (!process(start.get(), context)) {
                    throw new SemanticException(MUST_BE_AGGREGATE_OR_GROUP_BY, start.get(), "Window frame start must be an aggregate expression or appear in GROUP BY clause");
                }
            }
            if (node.getEnd().isPresent() && node.getEnd().get().getValue().isPresent()) {
                Expression endValue = node.getEnd().get().getValue().get();
                if (!process(endValue, context)) {
                    throw new SemanticException(MUST_BE_AGGREGATE_OR_GROUP_BY, endValue, "Window frame end must be an aggregate expression or appear in GROUP BY clause");
                }
            }

            return true;
        }

        @Override
        protected Boolean visitIdentifier(Identifier node, Void context)
        {
            // lambda arguments are bound by the enclosing lambda, not by a relation scope
            if (analysis.getLambdaArgumentReferences().containsKey(node)) {
                return true;
            }
            return isGroupingKey(node);
        }

        @Override
        protected Boolean visitDereferenceExpression(DereferenceExpression node, Void context)
        {
            if (columnReferences.containsKey(node)) {
                return isGroupingKey(node);
            }

            // Allow SELECT col1.f1 FROM table1 GROUP BY col1
            return process(node.getBase(), context);
        }

        /**
         * Returns whether the column reference {@code node} is a grouping field, or (in ORDER BY
         * mode) a field of the ORDER BY scope.
         *
         * @throws NullPointerException if {@code node} is not a known column reference
         */
        private boolean isGroupingKey(Expression node)
        {
            FieldId fieldId = columnReferences.get(node);
            requireNonNull(fieldId, () -> "No FieldId for " + node);

            if (orderByScope.isPresent() && isFieldFromScope(fieldId, orderByScope.get())) {
                return true;
            }

            return groupingFields.contains(fieldId);
        }

        @Override
        protected Boolean visitFieldReference(FieldReference node, Void context)
        {
            // in ORDER BY mode field references are accepted without a grouping check
            if (orderByScope.isPresent()) {
                return true;
            }

            FieldId fieldId = requireNonNull(columnReferences.get(node), "No FieldId for FieldReference");
            if (!groupingFields.contains(fieldId)) {
                throw new SemanticException(MUST_BE_AGGREGATE_OR_GROUP_BY, node, "Column %s not in GROUP BY clause", describeColumn(node));
            }
            return true;
        }

        /**
         * Renders a source-scope field for error messages as {@code 'alias.name'} or
         * {@code 'name'}, or as its 1-based position when the field is unnamed.
         */
        private String describeColumn(FieldReference node)
        {
            Field field = sourceScope.getRelationType().getFieldByIndex(node.getFieldIndex());
            if (!field.getName().isPresent()) {
                return Integer.toString(node.getFieldIndex() + 1);
            }
            if (field.getRelationAlias().isPresent()) {
                return String.format("'%s.%s'", field.getRelationAlias().get(), field.getName().get());
            }
            return "'" + field.getName().get() + "'";
        }

        @Override
        protected Boolean visitArithmeticUnary(ArithmeticUnaryExpression node, Void context)
        {
            return process(node.getValue(), context);
        }

        @Override
        protected Boolean visitNotExpression(NotExpression node, Void context)
        {
            return process(node.getValue(), context);
        }

        @Override
        protected Boolean visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context)
        {
            return process(node.getLeft(), context) && process(node.getRight(), context);
        }

        @Override
        protected Boolean visitIfExpression(IfExpression node, Void context)
        {
            ImmutableList.Builder<Expression> operands = ImmutableList.<Expression>builder()
                    .add(node.getCondition())
                    .add(node.getTrueValue());

            if (node.getFalseValue().isPresent()) {
                operands.add(node.getFalseValue().get());
            }

            return operands.build().stream().allMatch(expression -> process(expression, context));
        }

        @Override
        protected Boolean visitSimpleCaseExpression(SimpleCaseExpression node, Void context)
        {
            if (!process(node.getOperand(), context)) {
                return false;
            }

            for (WhenClause whenClause : node.getWhenClauses()) {
                if (!process(whenClause.getOperand(), context) || !process(whenClause.getResult(), context)) {
                    return false;
                }
            }

            if (node.getDefaultValue().isPresent() && !process(node.getDefaultValue().get(), context)) {
                return false;
            }

            return true;
        }

        @Override
        protected Boolean visitSearchedCaseExpression(SearchedCaseExpression node, Void context)
        {
            for (WhenClause whenClause : node.getWhenClauses()) {
                if (!process(whenClause.getOperand(), context) || !process(whenClause.getResult(), context)) {
                    return false;
                }
            }

            return !node.getDefaultValue().isPresent() || process(node.getDefaultValue().get(), context);
        }

        @Override
        protected Boolean visitTryExpression(TryExpression node, Void context)
        {
            return process(node.getInnerExpression(), context);
        }

        @Override
        public Boolean visitRow(Row node, final Void context)
        {
            return node.getItems().stream()
                    .allMatch(item -> process(item, context));
        }

        /**
         * During DESCRIBE parameters are accepted as-is; otherwise the bound parameter value is analyzed.
         *
         * @throws IllegalArgumentException if the parameter position has no bound value
         */
        @Override
        public Boolean visitParameter(Parameter node, Void context)
        {
            if (analysis.isDescribe()) {
                return true;
            }
            List<Expression> parameters = analysis.getParameters();
            checkArgument(node.getPosition() < parameters.size(), "Invalid parameter number %s, max values is %s", node.getPosition(), parameters.size() - 1);
            return process(parameters.get(node.getPosition()), context);
        }

        /**
         * Accepts any node structurally equal to a (parameter-rewritten) GROUP BY expression,
         * unless it references ORDER BY output columns (in ORDER BY mode) or references a lambda
         * argument; such nodes, and all others, are analyzed by the type-specific visits.
         */
        @Override
        public Boolean process(Node node, @Nullable Void context)
        {
            if (expressions.stream().anyMatch(node::equals)
                    && (!orderByScope.isPresent() || !hasOrderByReferencesToOutputColumns(node))
                    && !hasReferencesToLambdaArgument(node, analysis)) {
                return true;
            }

            return super.process(node, context);
        }
    }

    /**
     * Returns whether {@code node} references the ORDER BY scope.
     * Callers must ensure {@code orderByScope} is present.
     */
    private boolean hasOrderByReferencesToOutputColumns(Node node)
    {
        return hasReferencesToScope(node, analysis, orderByScope.get());
    }

    /**
     * Throws if {@code node} references the ORDER BY scope.
     * Callers must ensure {@code orderByScope} is present.
     */
    private void verifyNoOrderByReferencesToOutputColumns(Node node)
    {
        getReferencesToScope(node, analysis, orderByScope.get())
                .findFirst()
                .ifPresent(expression -> {
                    throw new SemanticException(REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION, expression, "Invalid reference to output projection attribute from ORDER BY aggregation");
                });
    }
}