/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.analyzer;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.OperatorNotFoundException;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.security.AccessControl;
import com.facebook.presto.security.DenyAllAccessControl;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.function.OperatorType;
import com.facebook.presto.spi.type.CharType;
import com.facebook.presto.spi.type.DecimalParseResult;
import com.facebook.presto.spi.type.Decimals;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignatureParameter;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.AtTimeZone;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BinaryLiteral;
import com.facebook.presto.sql.tree.BindExpression;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CharLiteral;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.CurrentTime;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.ExistsPredicate;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Extract;
import com.facebook.presto.sql.tree.FieldReference;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IntervalLiteral;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.Parameter;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.QuantifiedComparisonExpression;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.StackableAstVisitor;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubqueryExpression;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.TimeLiteral;
import com.facebook.presto.sql.tree.TimestampLiteral;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.sql.tree.WhenClause;
import com.facebook.presto.sql.tree.WindowFrame;
import com.facebook.presto.type.FunctionType;
import com.facebook.presto.type.RowType;
import com.facebook.presto.util.maps.IdentityLinkedHashMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.SliceUtf8;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import static com.facebook.presto.spi.function.OperatorType.SUBSCRIPT;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.type.DateType.DATE;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.type.IntegerType.INTEGER;
import static com.facebook.presto.spi.type.RealType.REAL;
import static com.facebook.presto.spi.type.SmallintType.SMALLINT;
import static com.facebook.presto.spi.type.TimeType.TIME;
import static com.facebook.presto.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.spi.type.TinyintType.TINYINT;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregatesOrWindowFunctions;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MULTIPLE_FIELDS_FROM_SUBQUERY;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.STANDALONE_LAMBDA;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH;
import static com.facebook.presto.sql.analyzer.SemanticExceptions.missingAttributeException;
import static com.facebook.presto.sql.tree.Extract.Field.TIMEZONE_HOUR;
import static com.facebook.presto.sql.tree.Extract.Field.TIMEZONE_MINUTE;
import static com.facebook.presto.type.ArrayParametricType.ARRAY;
import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME;
import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH;
import static com.facebook.presto.type.JsonType.JSON;
import static com.facebook.presto.type.RowType.RowField;
import static com.facebook.presto.type.UnknownType.UNKNOWN;
import static com.facebook.presto.util.DateTimeUtils.parseTimestampLiteral;
import static com.facebook.presto.util.DateTimeUtils.timeHasTimeZone;
import static com.facebook.presto.util.DateTimeUtils.timestampHasTimeZone;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;
import static java.util.Collections.newSetFromMap;
import static java.util.Objects.requireNonNull;
public class ExpressionAnalyzer
{
private final FunctionRegistry functionRegistry;
private final TypeManager typeManager;
private final Function<Node, StatementAnalyzer> statementAnalyzerFactory;
private final Map<Symbol, Type> symbolTypes;
private final boolean isDescribe;
private final IdentityLinkedHashMap<FunctionCall, Signature> resolvedFunctions = new IdentityLinkedHashMap<>();
private final Set<SubqueryExpression> scalarSubqueries = newSetFromMap(new IdentityLinkedHashMap<>());
private final Set<ExistsPredicate> existsSubqueries = newSetFromMap(new IdentityLinkedHashMap<>());
private final IdentityLinkedHashMap<Expression, Type> expressionCoercions = new IdentityLinkedHashMap<>();
private final Set<Expression> typeOnlyCoercions = newSetFromMap(new IdentityLinkedHashMap<>());
private final Set<InPredicate> subqueryInPredicates = newSetFromMap(new IdentityLinkedHashMap<>());
private final IdentityLinkedHashMap<Expression, FieldId> columnReferences = new IdentityLinkedHashMap<>();
private final IdentityLinkedHashMap<Expression, Type> expressionTypes = new IdentityLinkedHashMap<>();
private final Set<QuantifiedComparisonExpression> quantifiedComparisons = newSetFromMap(new IdentityLinkedHashMap<>());
// For lambda argument references, maps each QualifiedNameReference to the referenced LambdaArgumentDeclaration
private final IdentityLinkedHashMap<Identifier, LambdaArgumentDeclaration> lambdaArgumentReferences = new IdentityLinkedHashMap<>();
private final Session session;
private final List<Expression> parameters;
public ExpressionAnalyzer(
FunctionRegistry functionRegistry,
TypeManager typeManager,
Function<Node, StatementAnalyzer> statementAnalyzerFactory,
Session session,
Map<Symbol, Type> symbolTypes,
List<Expression> parameters,
boolean isDescribe)
{
this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry is null");
this.typeManager = requireNonNull(typeManager, "typeManager is null");
this.statementAnalyzerFactory = requireNonNull(statementAnalyzerFactory, "statementAnalyzerFactory is null");
this.session = requireNonNull(session, "session is null");
this.symbolTypes = ImmutableMap.copyOf(requireNonNull(symbolTypes, "symbolTypes is null"));
this.parameters = requireNonNull(parameters, "parameters is null");
this.isDescribe = isDescribe;
}
public IdentityLinkedHashMap<FunctionCall, Signature> getResolvedFunctions()
{
return resolvedFunctions;
}
public IdentityLinkedHashMap<Expression, Type> getExpressionTypes()
{
return expressionTypes;
}
public Type setExpressionType(Expression expression, Type type)
{
requireNonNull(expression, "expression cannot be null");
requireNonNull(type, "type cannot be null");
expressionTypes.put(expression, type);
return type;
}
public IdentityLinkedHashMap<Expression, Type> getExpressionCoercions()
{
return expressionCoercions;
}
public Set<Expression> getTypeOnlyCoercions()
{
return typeOnlyCoercions;
}
public Set<InPredicate> getSubqueryInPredicates()
{
return subqueryInPredicates;
}
public IdentityLinkedHashMap<Expression, FieldId> getColumnReferences()
{
return columnReferences;
}
public IdentityLinkedHashMap<Identifier, LambdaArgumentDeclaration> getLambdaArgumentReferences()
{
return lambdaArgumentReferences;
}
public Type analyze(Expression expression, Scope scope)
{
Visitor visitor = new Visitor(scope);
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda()));
}
private Type analyze(Expression expression, Scope scope, Context context)
{
Visitor visitor = new Visitor(scope);
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(context));
}
public Set<SubqueryExpression> getScalarSubqueries()
{
return scalarSubqueries;
}
public Set<ExistsPredicate> getExistsSubqueries()
{
return existsSubqueries;
}
public Set<QuantifiedComparisonExpression> getQuantifiedComparisons()
{
return quantifiedComparisons;
}
private class Visitor
extends StackableAstVisitor<Type, Context>
{
private final Scope scope;
private Visitor(Scope scope)
{
this.scope = requireNonNull(scope, "scope is null");
}
@SuppressWarnings("SuspiciousMethodCalls")
@Override
public Type process(Node node, @Nullable StackableAstVisitorContext<Context> context)
{
// don't double process a node
Type type = expressionTypes.get(node);
if (type != null) {
return type;
}
return super.process(node, context);
}
@Override
protected Type visitRow(Row node, StackableAstVisitorContext<Context> context)
{
List<Type> types = node.getItems().stream()
.map((child) -> process(child, context))
.collect(toImmutableList());
Type type = new RowType(types, Optional.empty());
return setExpressionType(node, type);
}
@Override
protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext<Context> context)
{
if (node.getPrecision() != null) {
throw new SemanticException(NOT_SUPPORTED, node, "non-default precision not yet supported");
}
Type type;
switch (node.getType()) {
case DATE:
type = DATE;
break;
case TIME:
type = TIME_WITH_TIME_ZONE;
break;
case LOCALTIME:
type = TIME;
break;
case TIMESTAMP:
type = TIMESTAMP_WITH_TIME_ZONE;
break;
case LOCALTIMESTAMP:
type = TIMESTAMP;
break;
default:
throw new SemanticException(NOT_SUPPORTED, node, "%s not yet supported", node.getType().getName());
}
return setExpressionType(node, type);
}
@Override
protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getName());
if (lambdaArgumentDeclaration != null) {
Type result = expressionTypes.get(lambdaArgumentDeclaration);
return setExpressionType(node, result);
}
}
Type type = symbolTypes.get(Symbol.from(node));
return setExpressionType(node, type);
}
@Override
protected Type visitIdentifier(Identifier node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getName());
if (lambdaArgumentDeclaration != null) {
lambdaArgumentReferences.put(node, lambdaArgumentDeclaration);
Type result = expressionTypes.get(lambdaArgumentDeclaration);
return setExpressionType(node, result);
}
}
return handleResolvedField(node, scope.resolveField(node, QualifiedName.of(node.getName())));
}
private Type handleResolvedField(Expression node, ResolvedField resolvedField)
{
return handleResolvedField(node, FieldId.from(resolvedField), resolvedField.getType());
}
private Type handleResolvedField(Expression node, FieldId fieldId, Type resolvedType)
{
FieldId previous = columnReferences.put(node, fieldId);
checkState(previous == null, "%s already known to refer to %s", node, previous);
return setExpressionType(node, resolvedType);
}
@Override
protected Type visitDereferenceExpression(DereferenceExpression node, StackableAstVisitorContext<Context> context)
{
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);
if (!context.getContext().isInLambda()) {
// If this Dereference looks like column reference, try match it to column first.
if (qualifiedName != null) {
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
return handleResolvedField(node, resolvedField.get());
}
if (!scope.isColumnReference(qualifiedName)) {
throw missingAttributeException(node, qualifiedName);
}
}
}
Type baseType = process(node.getBase(), context);
if (!(baseType instanceof RowType)) {
throw new SemanticException(TYPE_MISMATCH, node.getBase(), "Expression %s is not of type ROW", node.getBase());
}
RowType rowType = (RowType) baseType;
Type rowFieldType = null;
for (RowField rowField : rowType.getFields()) {
if (node.getFieldName().equalsIgnoreCase(rowField.getName().orElse(null))) {
rowFieldType = rowField.getType();
break;
}
}
if (rowFieldType == null) {
throw missingAttributeException(node);
}
return setExpressionType(node, rowFieldType);
}
@Override
protected Type visitNotExpression(NotExpression node, StackableAstVisitorContext<Context> context)
{
coerceType(context, node.getValue(), BOOLEAN, "Value of logical NOT expression");
return setExpressionType(node, BOOLEAN);
}
@Override
protected Type visitLogicalBinaryExpression(LogicalBinaryExpression node, StackableAstVisitorContext<Context> context)
{
coerceType(context, node.getLeft(), BOOLEAN, "Left side of logical expression");
coerceType(context, node.getRight(), BOOLEAN, "Right side of logical expression");
return setExpressionType(node, BOOLEAN);
}
@Override
protected Type visitComparisonExpression(ComparisonExpression node, StackableAstVisitorContext<Context> context)
{
OperatorType operatorType = OperatorType.valueOf(node.getType().name());
return getOperator(context, node, operatorType, node.getLeft(), node.getRight());
}
@Override
protected Type visitIsNullPredicate(IsNullPredicate node, StackableAstVisitorContext<Context> context)
{
process(node.getValue(), context);
return setExpressionType(node, BOOLEAN);
}
@Override
protected Type visitIsNotNullPredicate(IsNotNullPredicate node, StackableAstVisitorContext<Context> context)
{
process(node.getValue(), context);
return setExpressionType(node, BOOLEAN);
}
@Override
protected Type visitNullIfExpression(NullIfExpression node, StackableAstVisitorContext<Context> context)
{
Type firstType = process(node.getFirst(), context);
Type secondType = process(node.getSecond(), context);
if (!typeManager.getCommonSuperType(firstType, secondType).isPresent()) {
throw new SemanticException(TYPE_MISMATCH, node, "Types are not comparable with NULLIF: %s vs %s", firstType, secondType);
}
return setExpressionType(node, firstType);
}
@Override
protected Type visitIfExpression(IfExpression node, StackableAstVisitorContext<Context> context)
{
coerceType(context, node.getCondition(), BOOLEAN, "IF condition");
Type type;
if (node.getFalseValue().isPresent()) {
type = coerceToSingleType(context, node, "Result types for IF must be the same: %s vs %s", node.getTrueValue(), node.getFalseValue().get());
}
else {
type = process(node.getTrueValue(), context);
}
return setExpressionType(node, type);
}
@Override
protected Type visitSearchedCaseExpression(SearchedCaseExpression node, StackableAstVisitorContext<Context> context)
{
for (WhenClause whenClause : node.getWhenClauses()) {
coerceType(context, whenClause.getOperand(), BOOLEAN, "CASE WHEN clause");
}
Type type = coerceToSingleType(context,
"All CASE results must be the same type: %s",
getCaseResultExpressions(node.getWhenClauses(), node.getDefaultValue()));
setExpressionType(node, type);
for (WhenClause whenClause : node.getWhenClauses()) {
Type whenClauseType = process(whenClause.getResult(), context);
setExpressionType(whenClause, whenClauseType);
}
return type;
}
@Override
protected Type visitSimpleCaseExpression(SimpleCaseExpression node, StackableAstVisitorContext<Context> context)
{
for (WhenClause whenClause : node.getWhenClauses()) {
coerceToSingleType(context, whenClause, "CASE operand type does not match WHEN clause operand type: %s vs %s", node.getOperand(), whenClause.getOperand());
}
Type type = coerceToSingleType(context,
"All CASE results must be the same type: %s",
getCaseResultExpressions(node.getWhenClauses(), node.getDefaultValue()));
setExpressionType(node, type);
for (WhenClause whenClause : node.getWhenClauses()) {
Type whenClauseType = process(whenClause.getResult(), context);
setExpressionType(whenClause, whenClauseType);
}
return type;
}
private List<Expression> getCaseResultExpressions(List<WhenClause> whenClauses, Optional<Expression> defaultValue)
{
List<Expression> resultExpressions = new ArrayList<>();
for (WhenClause whenClause : whenClauses) {
resultExpressions.add(whenClause.getResult());
}
defaultValue.ifPresent(resultExpressions::add);
return resultExpressions;
}
@Override
protected Type visitCoalesceExpression(CoalesceExpression node, StackableAstVisitorContext<Context> context)
{
Type type = coerceToSingleType(context, "All COALESCE operands must be the same type: %s", node.getOperands());
return setExpressionType(node, type);
}
@Override
protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, StackableAstVisitorContext<Context> context)
{
switch (node.getSign()) {
case PLUS:
Type type = process(node.getValue(), context);
if (!type.equals(DOUBLE) && !type.equals(REAL) && !type.equals(BIGINT) && !type.equals(INTEGER) && !type.equals(SMALLINT) && !type.equals(TINYINT)) {
// TODO: figure out a type-agnostic way of dealing with this. Maybe add a special unary operator
// that types can chose to implement, or piggyback on the existence of the negation operator
throw new SemanticException(TYPE_MISMATCH, node, "Unary '+' operator cannot by applied to %s type", type);
}
return setExpressionType(node, type);
case MINUS:
return getOperator(context, node, OperatorType.NEGATION, node.getValue());
}
throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
}
@Override
protected Type visitArithmeticBinary(ArithmeticBinaryExpression node, StackableAstVisitorContext<Context> context)
{
return getOperator(context, node, OperatorType.valueOf(node.getType().name()), node.getLeft(), node.getRight());
}
@Override
protected Type visitLikePredicate(LikePredicate node, StackableAstVisitorContext<Context> context)
{
Type valueType = getVarcharType(node.getValue(), context);
Type patternType = getVarcharType(node.getPattern(), context);
coerceType(context, node.getValue(), valueType, "Left side of LIKE expression");
coerceType(context, node.getPattern(), patternType, "Pattern for LIKE expression");
if (node.getEscape() != null) {
Type escapeType = getVarcharType(node.getEscape(), context);
coerceType(context, node.getEscape(), escapeType, "Escape for LIKE expression");
}
return setExpressionType(node, BOOLEAN);
}
private Type getVarcharType(Expression value, StackableAstVisitorContext<Context> context)
{
Type type = process(value, context);
if (!(type instanceof VarcharType)) {
return VARCHAR;
}
return type;
}
@Override
protected Type visitSubscriptExpression(SubscriptExpression node, StackableAstVisitorContext<Context> context)
{
return getOperator(context, node, SUBSCRIPT, node.getBase(), node.getIndex());
}
@Override
protected Type visitArrayConstructor(ArrayConstructor node, StackableAstVisitorContext<Context> context)
{
Type type = coerceToSingleType(context, "All ARRAY elements must be the same type: %s", node.getValues());
Type arrayType = typeManager.getParameterizedType(ARRAY.getName(), ImmutableList.of(TypeSignatureParameter.of(type.getTypeSignature())));
return setExpressionType(node, arrayType);
}
@Override
protected Type visitStringLiteral(StringLiteral node, StackableAstVisitorContext<Context> context)
{
VarcharType type = VarcharType.createVarcharType(SliceUtf8.countCodePoints(node.getSlice()));
return setExpressionType(node, type);
}
@Override
protected Type visitCharLiteral(CharLiteral node, StackableAstVisitorContext<Context> context)
{
CharType type = CharType.createCharType(node.getValue().length());
return setExpressionType(node, type);
}
@Override
protected Type visitBinaryLiteral(BinaryLiteral node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(node, VARBINARY);
}
@Override
protected Type visitLongLiteral(LongLiteral node, StackableAstVisitorContext<Context> context)
{
if (node.getValue() >= Integer.MIN_VALUE && node.getValue() <= Integer.MAX_VALUE) {
return setExpressionType(node, INTEGER);
}
return setExpressionType(node, BIGINT);
}
@Override
protected Type visitDoubleLiteral(DoubleLiteral node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(node, DOUBLE);
}
@Override
protected Type visitDecimalLiteral(DecimalLiteral node, StackableAstVisitorContext<Context> context)
{
DecimalParseResult parseResult = Decimals.parse(node.getValue());
return setExpressionType(node, parseResult.getType());
}
@Override
protected Type visitBooleanLiteral(BooleanLiteral node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(node, BOOLEAN);
}
@Override
protected Type visitGenericLiteral(GenericLiteral node, StackableAstVisitorContext<Context> context)
{
Type type = typeManager.getType(parseTypeSignature(node.getType()));
if (type == null) {
throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType());
}
if (!JSON.equals(type)) {
try {
functionRegistry.getCoercion(VARCHAR, type);
}
catch (IllegalArgumentException e) {
throw new SemanticException(TYPE_MISMATCH, node, "No literal form for type %s", type);
}
}
return setExpressionType(node, type);
}
@Override
protected Type visitTimeLiteral(TimeLiteral node, StackableAstVisitorContext<Context> context)
{
Type type;
if (timeHasTimeZone(node.getValue())) {
type = TIME_WITH_TIME_ZONE;
}
else {
type = TIME;
}
return setExpressionType(node, type);
}
@Override
protected Type visitTimestampLiteral(TimestampLiteral node, StackableAstVisitorContext<Context> context)
{
try {
parseTimestampLiteral(session.getTimeZoneKey(), node.getValue());
}
catch (Exception e) {
throw new SemanticException(INVALID_LITERAL, node, "'%s' is not a valid timestamp literal", node.getValue());
}
Type type;
if (timestampHasTimeZone(node.getValue())) {
type = TIMESTAMP_WITH_TIME_ZONE;
}
else {
type = TIMESTAMP;
}
return setExpressionType(node, type);
}
@Override
protected Type visitIntervalLiteral(IntervalLiteral node, StackableAstVisitorContext<Context> context)
{
Type type;
if (node.isYearToMonth()) {
type = INTERVAL_YEAR_MONTH;
}
else {
type = INTERVAL_DAY_TIME;
}
return setExpressionType(node, type);
}
@Override
protected Type visitNullLiteral(NullLiteral node, StackableAstVisitorContext<Context> context)
{
return setExpressionType(node, UNKNOWN);
}
@Override
protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext<Context> context)
{
if (node.getWindow().isPresent()) {
for (Expression expression : node.getWindow().get().getPartitionBy()) {
process(expression, context);
Type type = expressionTypes.get(expression);
if (!type.isComparable()) {
throw new SemanticException(TYPE_MISMATCH, node, "%s is not comparable, and therefore cannot be used in window function PARTITION BY", type);
}
}
for (SortItem sortItem : getSortItemsFromOrderBy(node.getWindow().get().getOrderBy())) {
process(sortItem.getSortKey(), context);
Type type = expressionTypes.get(sortItem.getSortKey());
if (!type.isOrderable()) {
throw new SemanticException(TYPE_MISMATCH, node, "%s is not orderable, and therefore cannot be used in window function ORDER BY", type);
}
}
if (node.getWindow().get().getFrame().isPresent()) {
WindowFrame frame = node.getWindow().get().getFrame().get();
if (frame.getStart().getValue().isPresent()) {
Type type = process(frame.getStart().getValue().get(), context);
if (!type.equals(INTEGER) && !type.equals(BIGINT)) {
throw new SemanticException(TYPE_MISMATCH, node, "Window frame start value type must be INTEGER or BIGINT(actual %s)", type);
}
}
if (frame.getEnd().isPresent() && frame.getEnd().get().getValue().isPresent()) {
Type type = process(frame.getEnd().get().getValue().get(), context);
if (!type.equals(INTEGER) && !type.equals(BIGINT)) {
throw new SemanticException(TYPE_MISMATCH, node, "Window frame end value type must be INTEGER or BIGINT (actual %s)", type);
}
}
}
}
if (node.getFilter().isPresent()) {
Expression expression = node.getFilter().get();
process(expression, context);
}
ImmutableList.Builder<TypeSignatureProvider> argumentTypesBuilder = ImmutableList.builder();
for (Expression expression : node.getArguments()) {
if (expression instanceof LambdaExpression || expression instanceof BindExpression) {
argumentTypesBuilder.add(new TypeSignatureProvider(
types -> {
ExpressionAnalyzer innerExpressionAnalyzer = new ExpressionAnalyzer(
functionRegistry,
typeManager,
statementAnalyzerFactory,
session,
symbolTypes,
parameters,
isDescribe);
if (context.getContext().isInLambda()) {
for (LambdaArgumentDeclaration argument : context.getContext().getNameToLambdaArgumentDeclarationMap().values()) {
innerExpressionAnalyzer.setExpressionType(argument, expressionTypes.get(argument));
}
}
return innerExpressionAnalyzer.analyze(expression, scope, context.getContext().expectingLambda(types)).getTypeSignature();
}));
}
else {
argumentTypesBuilder.add(new TypeSignatureProvider(process(expression, context).getTypeSignature()));
}
}
ImmutableList<TypeSignatureProvider> argumentTypes = argumentTypesBuilder.build();
Signature function;
try {
function = functionRegistry.resolveFunction(node.getName(), argumentTypes);
}
catch (PrestoException e) {
if (e.getErrorCode().getCode() == StandardErrorCode.FUNCTION_NOT_FOUND.toErrorCode().getCode()) {
throw new SemanticException(SemanticErrorCode.FUNCTION_NOT_FOUND, node, e.getMessage());
}
if (e.getErrorCode().getCode() == StandardErrorCode.AMBIGUOUS_FUNCTION_CALL.toErrorCode().getCode()) {
throw new SemanticException(SemanticErrorCode.AMBIGUOUS_FUNCTION_CALL, node, e.getMessage());
}
throw e;
}
for (int i = 0; i < node.getArguments().size(); i++) {
Expression expression = node.getArguments().get(i);
Type expectedType = typeManager.getType(function.getArgumentTypes().get(i));
requireNonNull(expectedType, format("Type %s not found", function.getArgumentTypes().get(i)));
if (node.isDistinct() && !expectedType.isComparable()) {
throw new SemanticException(TYPE_MISMATCH, node, "DISTINCT can only be applied to comparable types (actual: %s)", expectedType);
}
if (argumentTypes.get(i).hasDependency()) {
FunctionType expectedFunctionType = (FunctionType) expectedType;
process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes())));
}
else {
Type actualType = typeManager.getType(argumentTypes.get(i).getTypeSignature());
coerceType(expression, actualType, expectedType, format("Function %s argument %d", function, i));
}
}
resolvedFunctions.put(node, function);
Type type = typeManager.getType(function.getReturnType());
return setExpressionType(node, type);
}
@Override
protected Type visitAtTimeZone(AtTimeZone node, StackableAstVisitorContext<Context> context)
{
Type valueType = process(node.getValue(), context);
process(node.getTimeZone(), context);
if (!valueType.equals(TIME_WITH_TIME_ZONE) && !valueType.equals(TIMESTAMP_WITH_TIME_ZONE) && !valueType.equals(TIME) && !valueType.equals(TIMESTAMP)) {
throw new SemanticException(TYPE_MISMATCH, node.getValue(), "Type of value must be a time or timestamp with or without time zone (actual %s)", valueType);
}
Type resultType = valueType;
if (valueType.equals(TIME)) {
resultType = TIME_WITH_TIME_ZONE;
}
else if (valueType.equals(TIMESTAMP)) {
resultType = TIMESTAMP_WITH_TIME_ZONE;
}
return setExpressionType(node, resultType);
}
@Override
protected Type visitParameter(Parameter node, StackableAstVisitorContext<Context> context)
{
if (isDescribe) {
return setExpressionType(node, UNKNOWN);
}
if (parameters.size() == 0) {
throw new SemanticException(INVALID_PARAMETER_USAGE, node, "query takes no parameters");
}
if (node.getPosition() >= parameters.size()) {
throw new SemanticException(INVALID_PARAMETER_USAGE, node, "invalid parameter index %s, max value is %s", node.getPosition(), parameters.size() - 1);
}
Type resultType = process(parameters.get(node.getPosition()), context);
return setExpressionType(node, resultType);
}
@Override
protected Type visitExtract(Extract node, StackableAstVisitorContext<Context> context)
{
Type type = process(node.getExpression(), context);
if (!isDateTimeType(type)) {
throw new SemanticException(TYPE_MISMATCH, node.getExpression(), "Type of argument to extract must be DATE, TIME, TIMESTAMP, or INTERVAL (actual %s)", type);
}
Extract.Field field = node.getField();
if ((field == TIMEZONE_HOUR || field == TIMEZONE_MINUTE) && !(type.equals(TIME_WITH_TIME_ZONE) || type.equals(TIMESTAMP_WITH_TIME_ZONE))) {
throw new SemanticException(TYPE_MISMATCH, node.getExpression(), "Type of argument to extract time zone field must have a time zone (actual %s)", type);
}
return setExpressionType(node, BIGINT);
}
private boolean isDateTimeType(Type type)
{
return type.equals(DATE) ||
type.equals(TIME) ||
type.equals(TIME_WITH_TIME_ZONE) ||
type.equals(TIMESTAMP) ||
type.equals(TIMESTAMP_WITH_TIME_ZONE) ||
type.equals(INTERVAL_DAY_TIME) ||
type.equals(INTERVAL_YEAR_MONTH);
}
@Override
protected Type visitBetweenPredicate(BetweenPredicate node, StackableAstVisitorContext<Context> context)
{
return getOperator(context, node, OperatorType.BETWEEN, node.getValue(), node.getMin(), node.getMax());
}
@Override
public Type visitTryExpression(TryExpression node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Try expression inside lambda expression is not support yet");
}
Type type = process(node.getInnerExpression(), context);
return setExpressionType(node, type);
}
@Override
public Type visitCast(Cast node, StackableAstVisitorContext<Context> context)
{
Type type = typeManager.getType(parseTypeSignature(node.getType()));
if (type == null) {
throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType());
}
if (type.equals(UNKNOWN)) {
throw new SemanticException(TYPE_MISMATCH, node, "UNKNOWN is not a valid type");
}
Type value = process(node.getExpression(), context);
if (!value.equals(UNKNOWN) && !node.isTypeOnly()) {
try {
functionRegistry.getCoercion(value, type);
}
catch (OperatorNotFoundException e) {
throw new SemanticException(TYPE_MISMATCH, node, "Cannot cast %s to %s", value, type);
}
}
return setExpressionType(node, type);
}
@Override
protected Type visitInPredicate(InPredicate node, StackableAstVisitorContext<Context> context)
{
Expression value = node.getValue();
process(value, context);
Expression valueList = node.getValueList();
process(valueList, context);
if (valueList instanceof InListExpression) {
InListExpression inListExpression = (InListExpression) valueList;
coerceToSingleType(context,
"IN value and list items must be the same type: %s",
ImmutableList.<Expression>builder().add(value).addAll(inListExpression.getValues()).build());
}
else if (valueList instanceof SubqueryExpression) {
coerceToSingleType(context, node, "value and result of subquery must be of the same type for IN expression: %s vs %s", value, valueList);
}
return setExpressionType(node, BOOLEAN);
}
@Override
protected Type visitInListExpression(InListExpression node, StackableAstVisitorContext<Context> context)
{
Type type = coerceToSingleType(context, "All IN list values must be the same type: %s", node.getValues());
setExpressionType(node, type);
return type; // TODO: this really should a be relation type
}
@Override
protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
throw new SemanticException(NOT_SUPPORTED, node, "Lambda expression cannot contain subqueries");
}
StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node);
Scope subqueryScope = Scope.builder()
.withParent(scope)
.build();
Scope queryScope = analyzer.analyze(node.getQuery(), subqueryScope);
// Subquery should only produce one column
if (queryScope.getRelationType().getVisibleFieldCount() != 1) {
throw new SemanticException(MULTIPLE_FIELDS_FROM_SUBQUERY,
node,
"Multiple columns returned by subquery are not yet supported. Found %s",
queryScope.getRelationType().getVisibleFieldCount());
}
Node previousNode = context.getPreviousNode().orElse(null);
if (previousNode instanceof InPredicate && ((InPredicate) previousNode).getValue() != node) {
subqueryInPredicates.add((InPredicate) previousNode);
}
else if (previousNode instanceof QuantifiedComparisonExpression) {
quantifiedComparisons.add((QuantifiedComparisonExpression) previousNode);
}
else {
scalarSubqueries.add(node);
}
Type type = getOnlyElement(queryScope.getRelationType().getVisibleFields()).getType();
return setExpressionType(node, type);
}
@Override
protected Type visitExists(ExistsPredicate node, StackableAstVisitorContext<Context> context)
{
StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node);
Scope subqueryScope = Scope.builder().withParent(scope).build();
analyzer.analyze(node.getSubquery(), subqueryScope);
existsSubqueries.add(node);
return setExpressionType(node, BOOLEAN);
}
@Override
protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, StackableAstVisitorContext<Context> context)
{
Expression value = node.getValue();
process(value, context);
Expression subquery = node.getSubquery();
process(subquery, context);
Type comparisonType = coerceToSingleType(context, node, "Value expression and result of subquery must be of the same type for quantified comparison: %s vs %s", value, subquery);
switch (node.getComparisonType()) {
case LESS_THAN:
case LESS_THAN_OR_EQUAL:
case GREATER_THAN:
case GREATER_THAN_OR_EQUAL:
if (!comparisonType.isOrderable()) {
throw new SemanticException(TYPE_MISMATCH, node, "Type [%s] must be orderable in order to be used in quantified comparison", comparisonType);
}
break;
case EQUAL:
case NOT_EQUAL:
if (!comparisonType.isComparable()) {
throw new SemanticException(TYPE_MISMATCH, node, "Type [%s] must be comparable in order to be used in quantified comparison", comparisonType);
}
break;
default:
checkState(false, "Unexpected comparison type: %s", node.getComparisonType());
}
return setExpressionType(node, BOOLEAN);
}
@Override
public Type visitFieldReference(FieldReference node, StackableAstVisitorContext<Context> context)
{
Type type = scope.getRelationType().getFieldByIndex(node.getFieldIndex()).getType();
return handleResolvedField(node, new FieldId(scope.getRelationId(), node.getFieldIndex()), type);
}
@Override
protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorContext<Context> context)
{
verifyNoAggregatesOrWindowFunctions(functionRegistry, node.getBody(), "Lambda expression");
if (!context.getContext().isExpectingLambda()) {
throw new SemanticException(STANDALONE_LAMBDA, node, "Lambda expression should always be used inside a function");
}
List<Type> types = context.getContext().getFunctionInputTypes();
List<LambdaArgumentDeclaration> lambdaArguments = node.getArguments();
verify(types.size() == lambdaArguments.size());
Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap = new HashMap<>();
if (context.getContext().isInLambda()) {
nameToLambdaArgumentDeclarationMap.putAll(context.getContext().getNameToLambdaArgumentDeclarationMap());
}
for (int i = 0; i < lambdaArguments.size(); i++) {
LambdaArgumentDeclaration lambdaArgument = lambdaArguments.get(i);
nameToLambdaArgumentDeclarationMap.put(lambdaArgument.getName(), lambdaArgument);
setExpressionType(lambdaArgument, types.get(i));
}
Type returnType = process(node.getBody(), new StackableAstVisitorContext<Context>(Context.inLambda(nameToLambdaArgumentDeclarationMap)));
FunctionType functionType = new FunctionType(types, returnType);
return setExpressionType(node, functionType);
}
@Override
protected Type visitBindExpression(BindExpression node, StackableAstVisitorContext<Context> context)
{
verify(context.getContext().isExpectingLambda(), "bind expression found when lambda is not expected");
List<Type> functionInputTypes = ImmutableList.<Type>builder()
.add(process(node.getValue(), new StackableAstVisitorContext<>(context.getContext().notExpectingLambda())))
.addAll(context.getContext().getFunctionInputTypes())
.build();
FunctionType functionType = (FunctionType) process(node.getFunction(), new StackableAstVisitorContext<>(context.getContext().expectingLambda(functionInputTypes)));
List<Type> argumentTypes = functionType.getArgumentTypes();
verify(argumentTypes.size() == functionInputTypes.size());
verify(functionInputTypes.get(0) == argumentTypes.get(0));
FunctionType result = new FunctionType(argumentTypes.subList(1, argumentTypes.size()), functionType.getReturnType());
return setExpressionType(node, result);
}
@Override
protected Type visitExpression(Expression node, StackableAstVisitorContext<Context> context)
{
throw new SemanticException(NOT_SUPPORTED, node, "not yet implemented: " + node.getClass().getName());
}
@Override
protected Type visitNode(Node node, StackableAstVisitorContext<Context> context)
{
throw new SemanticException(NOT_SUPPORTED, node, "not yet implemented: " + node.getClass().getName());
}
private Type getOperator(StackableAstVisitorContext<Context> context, Expression node, OperatorType operatorType, Expression... arguments)
{
ImmutableList.Builder<Type> argumentTypes = ImmutableList.builder();
for (Expression expression : arguments) {
argumentTypes.add(process(expression, context));
}
Signature operatorSignature;
try {
operatorSignature = functionRegistry.resolveOperator(operatorType, argumentTypes.build());
}
catch (OperatorNotFoundException e) {
throw new SemanticException(TYPE_MISMATCH, node, "%s", e.getMessage());
}
catch (PrestoException e) {
if (e.getErrorCode().getCode() == StandardErrorCode.AMBIGUOUS_FUNCTION_CALL.toErrorCode().getCode()) {
throw new SemanticException(SemanticErrorCode.AMBIGUOUS_FUNCTION_CALL, node, e.getMessage());
}
throw e;
}
for (int i = 0; i < arguments.length; i++) {
Expression expression = arguments[i];
Type type = typeManager.getType(operatorSignature.getArgumentTypes().get(i));
coerceType(context, expression, type, format("Operator %s argument %d", operatorSignature, i));
}
Type type = typeManager.getType(operatorSignature.getReturnType());
return setExpressionType(node, type);
}
private void coerceType(Expression expression, Type actualType, Type expectedType, String message)
{
if (!actualType.equals(expectedType)) {
if (!typeManager.canCoerce(actualType, expectedType)) {
throw new SemanticException(TYPE_MISMATCH, expression, message + " must evaluate to a %s (actual: %s)", expectedType, actualType);
}
addOrReplaceExpressionCoercion(expression, actualType, expectedType);
}
}
private void coerceType(StackableAstVisitorContext<Context> context, Expression expression, Type expectedType, String message)
{
Type actualType = process(expression, context);
coerceType(expression, actualType, expectedType, message);
}
private Type coerceToSingleType(StackableAstVisitorContext<Context> context, Node node, String message, Expression first, Expression second)
{
Type firstType = UNKNOWN;
if (first != null) {
firstType = process(first, context);
}
Type secondType = UNKNOWN;
if (second != null) {
secondType = process(second, context);
}
// coerce types if possible
Optional<Type> superTypeOptional = typeManager.getCommonSuperType(firstType, secondType);
if (superTypeOptional.isPresent()
&& typeManager.canCoerce(firstType, superTypeOptional.get())
&& typeManager.canCoerce(secondType, superTypeOptional.get())) {
Type superType = superTypeOptional.get();
if (!firstType.equals(superType)) {
addOrReplaceExpressionCoercion(first, firstType, superType);
}
if (!secondType.equals(superType)) {
addOrReplaceExpressionCoercion(second, secondType, superType);
}
return superType;
}
throw new SemanticException(TYPE_MISMATCH, node, message, firstType, secondType);
}
private Type coerceToSingleType(StackableAstVisitorContext<Context> context, String message, List<Expression> expressions)
{
// determine super type
Type superType = UNKNOWN;
for (Expression expression : expressions) {
Optional<Type> newSuperType = typeManager.getCommonSuperType(superType, process(expression, context));
if (!newSuperType.isPresent()) {
throw new SemanticException(TYPE_MISMATCH, expression, message, superType);
}
superType = newSuperType.get();
}
// verify all expressions can be coerced to the superType
for (Expression expression : expressions) {
Type type = process(expression, context);
if (!type.equals(superType)) {
if (!typeManager.canCoerce(type, superType)) {
throw new SemanticException(TYPE_MISMATCH, expression, message, superType);
}
addOrReplaceExpressionCoercion(expression, type, superType);
}
}
return superType;
}
private void addOrReplaceExpressionCoercion(Expression expression, Type type, Type superType)
{
expressionCoercions.put(expression, superType);
if (typeManager.isTypeOnlyCoercion(type, superType)) {
typeOnlyCoercions.add(expression);
}
else if (typeOnlyCoercions.contains(expression)) {
typeOnlyCoercions.remove(expression);
}
}
}
private static class Context
{
// functionInputTypes and nameToLambdaDeclarationMap can be null or non-null independently. All 4 combinations are possible.
// The list of types when expecting a lambda (i.e. processing lambda parameters of a function); null otherwise.
// Empty list represents expecting a lambda with no arguments.
private final List<Type> functionInputTypes;
// The mapping from names to corresponding lambda argument declarations when inside a lambda; null otherwise.
// Empty map means that the all lambda expressions surrounding the current node has no arguments.
private final Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap;
private Context(
List<Type> functionInputTypes,
Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
{
this.functionInputTypes = functionInputTypes;
this.nameToLambdaArgumentDeclarationMap = nameToLambdaArgumentDeclarationMap;
}
public static Context notInLambda()
{
return new Context(null, null);
}
public static Context inLambda(Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
{
return new Context(null, requireNonNull(nameToLambdaArgumentDeclarationMap, "nameToLambdaArgumentDeclarationMap is null"));
}
public Context expectingLambda(List<Type> functionInputTypes)
{
return new Context(requireNonNull(functionInputTypes, "functionInputTypes is null"), this.nameToLambdaArgumentDeclarationMap);
}
public Context notExpectingLambda()
{
return new Context(null, this.nameToLambdaArgumentDeclarationMap);
}
public boolean isInLambda()
{
return nameToLambdaArgumentDeclarationMap != null;
}
public boolean isExpectingLambda()
{
return functionInputTypes != null;
}
public Map<String, LambdaArgumentDeclaration> getNameToLambdaArgumentDeclarationMap()
{
checkState(isInLambda());
return nameToLambdaArgumentDeclarationMap;
}
public List<Type> getFunctionInputTypes()
{
checkState(isExpectingLambda());
return functionInputTypes;
}
}
public static IdentityLinkedHashMap<Expression, Type> getExpressionTypes(
Session session,
Metadata metadata,
SqlParser sqlParser,
Map<Symbol, Type> types,
Expression expression,
List<Expression> parameters)
{
return getExpressionTypes(session, metadata, sqlParser, types, expression, parameters, false);
}
public static IdentityLinkedHashMap<Expression, Type> getExpressionTypes(
Session session,
Metadata metadata,
SqlParser sqlParser,
Map<Symbol, Type> types,
Expression expression,
List<Expression> parameters,
boolean isDescribe)
{
return getExpressionTypes(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters, isDescribe);
}
public static IdentityLinkedHashMap<Expression, Type> getExpressionTypes(
Session session,
Metadata metadata,
SqlParser sqlParser,
Map<Symbol, Type> types,
Iterable<Expression> expressions,
List<Expression> parameters,
boolean isDescribe)
{
return analyzeExpressionsWithSymbols(session, metadata, sqlParser, types, expressions, parameters, isDescribe).getExpressionTypes();
}
public static IdentityLinkedHashMap<Expression, Type> getExpressionTypesFromInput(
Session session,
Metadata metadata,
SqlParser sqlParser,
Map<Integer, Type> types,
Expression expression,
List<Expression> parameters)
{
return getExpressionTypesFromInput(session, metadata, sqlParser, types, ImmutableList.of(expression), parameters);
}
public static IdentityLinkedHashMap<Expression, Type> getExpressionTypesFromInput(
Session session,
Metadata metadata,
SqlParser sqlParser,
Map<Integer, Type> types,
Iterable<Expression> expressions,
List<Expression> parameters)
{
return analyzeExpressionsWithInputs(session, metadata, sqlParser, types, expressions, parameters).getExpressionTypes();
}
public static ExpressionAnalysis analyzeExpressionsWithSymbols(
Session session,
Metadata metadata,
SqlParser sqlParser,
Map<Symbol, Type> types,
Iterable<Expression> expressions,
List<Expression> parameters,
boolean isDescribe)
{
return analyzeExpressions(session, metadata, sqlParser, new RelationType(), types, expressions, parameters, isDescribe);
}
private static ExpressionAnalysis analyzeExpressionsWithInputs(
Session session,
Metadata metadata,
SqlParser sqlParser,
Map<Integer, Type> types,
Iterable<Expression> expressions,
List<Expression> parameters)
{
Field[] fields = new Field[types.size()];
for (Entry<Integer, Type> entry : types.entrySet()) {
fields[entry.getKey()] = Field.newUnqualified(Optional.empty(), entry.getValue());
}
RelationType tupleDescriptor = new RelationType(fields);
return analyzeExpressions(session, metadata, sqlParser, tupleDescriptor, ImmutableMap.of(), expressions, parameters);
}
public static ExpressionAnalysis analyzeExpressions(
Session session,
Metadata metadata,
SqlParser sqlParser,
RelationType tupleDescriptor,
Map<Symbol, Type> types,
Iterable<? extends Expression> expressions,
List<Expression> parameters)
{
return analyzeExpressions(session, metadata, sqlParser, tupleDescriptor, types, expressions, parameters, false);
}
private static ExpressionAnalysis analyzeExpressions(
Session session,
Metadata metadata,
SqlParser sqlParser,
RelationType tupleDescriptor,
Map<Symbol, Type> types,
Iterable<? extends Expression> expressions,
List<Expression> parameters,
boolean isDescribe)
{
// expressions at this point can not have sub queries so deny all access checks
// in the future, we will need a full access controller here to verify access to functions
Analysis analysis = new Analysis(null, parameters, isDescribe);
ExpressionAnalyzer analyzer = create(analysis, session, metadata, sqlParser, new DenyAllAccessControl(), types);
for (Expression expression : expressions) {
analyzer.analyze(expression, Scope.builder().withRelationType(RelationId.anonymous(), tupleDescriptor).build());
}
return new ExpressionAnalysis(
analyzer.getExpressionTypes(),
analyzer.getExpressionCoercions(),
analyzer.getSubqueryInPredicates(),
analyzer.getScalarSubqueries(),
analyzer.getExistsSubqueries(),
analyzer.getColumnReferences(),
analyzer.getTypeOnlyCoercions(),
analyzer.getQuantifiedComparisons(),
analyzer.getLambdaArgumentReferences());
}
public static ExpressionAnalysis analyzeExpression(
Session session,
Metadata metadata,
AccessControl accessControl,
SqlParser sqlParser,
Scope scope,
Analysis analysis,
Expression expression)
{
ExpressionAnalyzer analyzer = create(analysis, session, metadata, sqlParser, accessControl, ImmutableMap.of());
analyzer.analyze(expression, scope);
IdentityLinkedHashMap<Expression, Type> expressionTypes = analyzer.getExpressionTypes();
IdentityLinkedHashMap<Expression, Type> expressionCoercions = analyzer.getExpressionCoercions();
Set<Expression> typeOnlyCoercions = analyzer.getTypeOnlyCoercions();
IdentityLinkedHashMap<FunctionCall, Signature> resolvedFunctions = analyzer.getResolvedFunctions();
analysis.addTypes(expressionTypes);
analysis.addCoercions(expressionCoercions, typeOnlyCoercions);
analysis.addFunctionSignatures(resolvedFunctions);
analysis.addColumnReferences(analyzer.getColumnReferences());
analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences());
return new ExpressionAnalysis(
expressionTypes,
expressionCoercions,
analyzer.getSubqueryInPredicates(),
analyzer.getScalarSubqueries(),
analyzer.getExistsSubqueries(),
analyzer.getColumnReferences(),
analyzer.getTypeOnlyCoercions(),
analyzer.getQuantifiedComparisons(),
analyzer.getLambdaArgumentReferences());
}
public static ExpressionAnalyzer create(
Analysis analysis,
Session session,
Metadata metadata,
SqlParser sqlParser,
AccessControl accessControl)
{
return create(analysis, session, metadata, sqlParser, accessControl, ImmutableMap.of());
}
public static ExpressionAnalyzer create(
Analysis analysis,
Session session,
Metadata metadata,
SqlParser sqlParser,
AccessControl accessControl,
Map<Symbol, Type> types)
{
return new ExpressionAnalyzer(
metadata.getFunctionRegistry(),
metadata.getTypeManager(),
node -> new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session),
session,
types,
analysis.getParameters(),
analysis.isDescribe());
}
public static ExpressionAnalyzer createConstantAnalyzer(Metadata metadata, Session session, List<Expression> parameters)
{
return createWithoutSubqueries(
metadata.getFunctionRegistry(),
metadata.getTypeManager(),
session,
parameters,
EXPRESSION_NOT_CONSTANT,
"Constant expression cannot contain a subquery",
false);
}
public static ExpressionAnalyzer createConstantAnalyzer(Metadata metadata, Session session, List<Expression> parameters, boolean isDescribe)
{
return createWithoutSubqueries(
metadata.getFunctionRegistry(),
metadata.getTypeManager(),
session,
parameters,
EXPRESSION_NOT_CONSTANT,
"Constant expression cannot contain a subquery",
isDescribe);
}
public static ExpressionAnalyzer createWithoutSubqueries(
FunctionRegistry functionRegistry,
TypeManager typeManager,
Session session,
List<Expression> parameters,
SemanticErrorCode errorCode,
String message,
boolean isDescribe)
{
return new ExpressionAnalyzer(functionRegistry, typeManager, node -> {
throw new SemanticException(errorCode, node, message);
}, session, ImmutableMap.of(), parameters, isDescribe);
}
}