/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.Session;
import com.facebook.presto.client.FailureInfo;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.scalar.ArraySubscriptOperator;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementation;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.RecordCursor;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.InterleavedBlockBuilder;
import com.facebook.presto.spi.function.OperatorType;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.facebook.presto.sql.planner.optimizations.CanonicalizeExpressions;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BindExpression;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpressionType;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.ExistsPredicate;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.FieldReference;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.Parameter;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.QuantifiedComparisonExpression;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubqueryExpression;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.sql.tree.WhenClause;
import com.facebook.presto.type.ArrayType;
import com.facebook.presto.type.FunctionType;
import com.facebook.presto.type.LikeFunctions;
import com.facebook.presto.type.RowType;
import com.facebook.presto.type.RowType.RowField;
import com.facebook.presto.util.Failures;
import com.facebook.presto.util.FastutilSetHelper;
import com.facebook.presto.util.maps.IdentityLinkedHashMap;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Defaults;
import com.google.common.base.Functions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.primitives.Primitives;
import io.airlift.joni.Regex;
import io.airlift.json.JsonCodec;
import io.airlift.slice.Slice;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.spi.type.VarcharType.createVarcharType;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.createConstantAnalyzer;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.gen.TryCodeGenerator.tryExpressionExceptionHandler;
import static com.facebook.presto.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter;
import static com.facebook.presto.sql.planner.LiteralInterpreter.toExpression;
import static com.facebook.presto.sql.planner.LiteralInterpreter.toExpressions;
import static com.facebook.presto.type.LikeFunctions.isLikePattern;
import static com.facebook.presto.type.LikeFunctions.unescapeLiteralLikePattern;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.instanceOf;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.any;
import static java.util.Objects.requireNonNull;
public class ExpressionInterpreter
{
// the root expression this instance interprets (or optimizes)
private final Expression expression;
private final Metadata metadata;
// connector-level view of the session, derived once in the constructor via Session.toConnectorSession()
private final ConnectorSession session;
// true: optimizer mode (results may be partially evaluated Expressions); false: interpreter mode
private final boolean optimize;
// analyzed types, stored in an identity-keyed map (IdentityLinkedHashMap)
private final IdentityLinkedHashMap<Expression, Type> expressionTypes;
private final Visitor visitor;
// identity-based cache for LIKE expressions with constant pattern and escape char
private final IdentityLinkedHashMap<LikePredicate, Regex> likePatternCache = new IdentityLinkedHashMap<>();
// identity-based cache of hash sets for IN lists; a key that is present but mapped to null means
// the list was already inspected and the hash-set fast path does not apply to it
private final IdentityLinkedHashMap<InListExpression, Set<?>> inListCache = new IdentityLinkedHashMap<>();
/**
 * Creates an instance in interpreter mode, which fully evaluates {@code expression}.
 *
 * @param expression root expression to evaluate; must not be null
 * @param metadata metadata used to resolve functions and operators; must not be null
 * @param session session the expression is evaluated in; must not be null
 * @param expressionTypes analyzed types, which must contain {@code expression}
 * @return an interpreter whose {@code evaluate(...)} methods may be used
 */
public static ExpressionInterpreter expressionInterpreter(Expression expression, Metadata metadata, Session session, IdentityLinkedHashMap<Expression, Type> expressionTypes)
{
    boolean optimizerMode = false;
    return new ExpressionInterpreter(
            requireNonNull(expression, "expression is null"),
            requireNonNull(metadata, "metadata is null"),
            requireNonNull(session, "session is null"),
            expressionTypes,
            optimizerMode);
}
/**
 * Creates an instance in optimizer mode, which partially evaluates {@code expression} and may
 * return a simplified {@link Expression} instead of a value.
 *
 * @param expression root expression to optimize; must not be null
 * @param metadata metadata used to resolve functions and operators; must not be null
 * @param session session the expression is optimized in; must not be null
 * @param expressionTypes analyzed types, which must contain {@code expression}
 * @return an optimizer whose {@code optimize(SymbolResolver)} method may be used
 */
public static ExpressionInterpreter expressionOptimizer(Expression expression, Metadata metadata, Session session, IdentityLinkedHashMap<Expression, Type> expressionTypes)
{
    boolean optimizerMode = true;
    return new ExpressionInterpreter(
            requireNonNull(expression, "expression is null"),
            requireNonNull(metadata, "metadata is null"),
            requireNonNull(session, "session is null"),
            expressionTypes,
            optimizerMode);
}
/**
 * Analyzes a constant expression, coerces it to {@code expectedType} and evaluates it.
 *
 * @param expression the constant expression
 * @param expectedType the type the result must be coerced to
 * @param metadata metadata used for analysis and evaluation
 * @param session the current session
 * @param parameters values bound to {@code ?} parameters
 * @return the evaluated value, coerced to {@code expectedType}
 * @throws SemanticException with {@code TYPE_MISMATCH} if the analyzed type cannot be coerced to {@code expectedType}
 */
public static Object evaluateConstantExpression(Expression expression, Type expectedType, Metadata metadata, Session session, List<Expression> parameters)
{
    ExpressionAnalyzer analyzer = createConstantAnalyzer(metadata, session, parameters);
    analyzer.analyze(expression, Scope.create());
    Type actualType = analyzer.getExpressionTypes().get(expression);
    if (!metadata.getTypeManager().canCoerce(actualType, expectedType)) {
        // the coercion goes FROM the analyzed type TO the expected type; report them in that order
        throw new SemanticException(SemanticErrorCode.TYPE_MISMATCH, expression, String.format("Cannot cast type %s to %s",
                actualType.getTypeSignature(),
                expectedType.getTypeSignature()));
    }
    IdentityLinkedHashMap<Expression, Type> coercions = new IdentityLinkedHashMap<>();
    coercions.putAll(analyzer.getExpressionCoercions());
    // force the root to be cast to the expected type
    coercions.put(expression, expectedType);
    return evaluateConstantExpression(expression, coercions, metadata, session, ImmutableSet.of(), parameters);
}
/**
 * Evaluates a constant expression after applying the given coercions.
 *
 * <p>Pipeline: verify the expression has no column references, wrap coerced nodes in CASTs,
 * desugar, canonicalize, then run the interpreter. The expression is re-analyzed after each
 * rewrite because rewrites produce new nodes that are missing from the identity-keyed type maps.
 *
 * @param expression the constant expression
 * @param coercions node-identity to target-type map; each listed node is wrapped in a CAST
 * @param metadata metadata used for analysis and evaluation
 * @param session the current session
 * @param columnReferences expressions that must be treated as column references (and are therefore rejected)
 * @param parameters values bound to {@code ?} parameters
 * @return the evaluated value
 */
public static Object evaluateConstantExpression(
        Expression expression,
        IdentityLinkedHashMap<Expression, Type> coercions,
        Metadata metadata, Session session,
        Set<Expression> columnReferences,
        List<Expression> parameters)
{
    requireNonNull(columnReferences, "columnReferences is null");
    verifyExpressionIsConstant(columnReferences, expression);

    // wrap every node that has a registered coercion in an explicit CAST (children are rewritten first)
    ExpressionRewriter<Void> coercionInserter = new ExpressionRewriter<Void>()
    {
        @Override
        public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
        {
            Expression withRewrittenChildren = treeRewriter.defaultRewrite(node, context);
            Type targetType = coercions.get(node);
            if (targetType == null) {
                return withRewrittenChildren;
            }
            return new Cast(withRewrittenChildren, targetType.getTypeSignature().toString());
        }
    };
    Expression coerced = ExpressionTreeRewriter.rewriteWith(coercionInserter, expression);

    // the inserted CASTs are new nodes with no type entries, so analyze again before desugaring
    ExpressionAnalyzer coercedAnalysis = createConstantAnalyzer(metadata, session, parameters);
    coercedAnalysis.analyze(coerced, Scope.create());
    Expression desugared = ExpressionTreeRewriter.rewriteWith(new DesugaringRewriter(coercedAnalysis.getExpressionTypes()), coerced);

    // expressionInterpreter/optimizer only understands a subset of expression types
    // TODO: remove this when the new expression tree is implemented
    Expression canonicalized = CanonicalizeExpressions.canonicalizeExpression(desugared);

    // canonicalization may have replaced nodes, breaking identity-based maps; analyze once more
    ExpressionAnalyzer canonicalAnalysis = createConstantAnalyzer(metadata, session, parameters);
    canonicalAnalysis.analyze(canonicalized, Scope.create());

    Object value = expressionInterpreter(canonicalized, metadata, session, canonicalAnalysis.getExpressionTypes()).evaluate(0);
    verify(!(value instanceof Expression), "Expression interpreter returned an unresolved expression");
    return value;
}
/**
 * Rejects expressions that reference columns.
 *
 * @param columnReferences dereference expressions that are known to be column references
 * @param expression the expression to check
 * @throws SemanticException with {@code EXPRESSION_NOT_CONSTANT} if an identifier, field reference
 *         or listed column reference is found
 */
public static void verifyExpressionIsConstant(Set<Expression> columnReferences, Expression expression)
{
    ConstantExpressionVerifierVisitor verifier = new ConstantExpressionVerifierVisitor(columnReferences, expression);
    verifier.process(expression, null);
}
private ExpressionInterpreter(Expression expression, Metadata metadata, Session session, IdentityLinkedHashMap<Expression, Type> expressionTypes, boolean optimize)
{
this.expression = expression;
this.metadata = metadata;
this.session = session.toConnectorSession();
this.expressionTypes = expressionTypes;
verify((expressionTypes.containsKey(expression)));
this.optimize = optimize;
this.visitor = new Visitor();
}
/**
 * Returns the analyzed type of the root expression.
 */
public Type getType()
{
    Type rootType = this.expressionTypes.get(this.expression);
    return rootType;
}
/**
 * Evaluates the expression against the current row of a record cursor.
 *
 * @param inputs cursor whose channels are read by field references
 * @return the value, or {@code null} for SQL NULL
 * @throws IllegalStateException if this instance was created as an optimizer
 */
public Object evaluate(RecordCursor inputs)
{
    checkState(!optimize, "evaluate(RecordCursor) not allowed for optimizer");
    Object cursorContext = inputs;
    return visitor.process(expression, cursorContext);
}
/**
 * Evaluates the expression at one position of a single page of blocks.
 *
 * @param position row position within {@code inputs}
 * @param inputs one block per input channel
 * @return the value, or {@code null} for SQL NULL
 * @throws IllegalStateException if this instance was created as an optimizer
 */
public Object evaluate(int position, Block... inputs)
{
    checkState(!optimize, "evaluate(int, Block...) not allowed for optimizer");
    SinglePagePositionContext pageContext = new SinglePagePositionContext(position, inputs);
    return visitor.process(expression, pageContext);
}
/**
 * Evaluates the expression over a pair of positions, one on each of two pages of blocks.
 *
 * @return the value, or {@code null} for SQL NULL
 * @throws IllegalStateException if this instance was created as an optimizer
 */
public Object evaluate(int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks)
{
    checkState(!optimize, "evaluate(int, Block[], int, Block[]) not allowed for optimizer");
    TwoPagesPositionContext pairContext = new TwoPagesPositionContext(leftPosition, leftBlocks, rightPosition, rightBlocks);
    return visitor.process(expression, pairContext);
}
/**
 * Partially evaluates the expression, resolving symbol references through {@code inputs}.
 *
 * @param inputs resolver for symbol values
 * @return a constant value, {@code null}, or a simplified {@link Expression}
 * @throws IllegalStateException if this instance was created as an interpreter
 */
public Object optimize(SymbolResolver inputs)
{
    // message names this method (it previously said "evaluate(SymbolResolver)")
    checkState(optimize, "optimize(SymbolResolver) not allowed for interpreter");
    return visitor.process(expression, inputs);
}
/**
 * Traversal that throws when it finds anything that refers to a column: any identifier, any field
 * reference, or a dereference that appears in the supplied column-reference set.
 */
private static class ConstantExpressionVerifierVisitor
        extends DefaultTraversalVisitor<Void, Void>
{
    private final Set<Expression> columnReferences;
    // the whole expression being verified; errors are reported against it rather than the sub-node
    private final Expression expression;

    public ConstantExpressionVerifierVisitor(Set<Expression> columnReferences, Expression expression)
    {
        this.columnReferences = columnReferences;
        this.expression = expression;
    }

    @Override
    protected Void visitDereferenceExpression(DereferenceExpression node, Void context)
    {
        if (!columnReferences.contains(node)) {
            // not a column itself, but its base might still contain one
            process(node.getBase(), context);
            return null;
        }
        throw columnReferenceError();
    }

    @Override
    protected Void visitIdentifier(Identifier node, Void context)
    {
        throw columnReferenceError();
    }

    @Override
    protected Void visitFieldReference(FieldReference node, Void context)
    {
        throw columnReferenceError();
    }

    private SemanticException columnReferenceError()
    {
        return new SemanticException(EXPRESSION_NOT_CONSTANT, expression, "Constant expression cannot contain column references");
    }
}
@SuppressWarnings("FloatingPointEquality")
private class Visitor
extends AstVisitor<Object, Object>
{
/**
 * Reads input channel {@code node.getFieldIndex()} at the current position.
 *
 * <p>Supports page-position contexts and record cursors. SQL NULL is returned as {@code null};
 * other values are returned using the type's Java representation (boolean, long, double, Slice or Block).
 *
 * @throws UnsupportedOperationException for any other Java type, or if the context is neither
 *         a page-position context nor a record cursor
 */
@Override
public Object visitFieldReference(FieldReference node, Object context)
{
    Type type = expressionTypes.get(node);
    int channel = node.getFieldIndex();

    if (context instanceof PagePositionContext) {
        PagePositionContext pagePositionContext = (PagePositionContext) context;
        int position = pagePositionContext.getPosition(channel);
        Block block = pagePositionContext.getBlock(channel);
        if (block.isNull(position)) {
            return null;
        }
        Class<?> javaType = type.getJavaType();
        if (javaType == boolean.class) {
            return type.getBoolean(block, position);
        }
        if (javaType == long.class) {
            return type.getLong(block, position);
        }
        if (javaType == double.class) {
            return type.getDouble(block, position);
        }
        if (javaType == Slice.class) {
            return type.getSlice(block, position);
        }
        if (javaType == Block.class) {
            return type.getObject(block, position);
        }
        throw new UnsupportedOperationException("not yet implemented");
    }

    if (context instanceof RecordCursor) {
        RecordCursor cursor = (RecordCursor) context;
        if (cursor.isNull(channel)) {
            return null;
        }
        Class<?> javaType = type.getJavaType();
        if (javaType == boolean.class) {
            return cursor.getBoolean(channel);
        }
        if (javaType == long.class) {
            return cursor.getLong(channel);
        }
        if (javaType == double.class) {
            return cursor.getDouble(channel);
        }
        if (javaType == Slice.class) {
            return cursor.getSlice(channel);
        }
        if (javaType == Block.class) {
            return cursor.getObject(channel);
        }
        throw new UnsupportedOperationException("not yet implemented");
    }

    // typo fixed: previously "Inputs or cursor myst be set"
    throw new UnsupportedOperationException("Inputs or cursor must be set");
}
/**
 * Evaluates {@code base.field} on a row value.
 *
 * <p>Returns the node unchanged when the base has no analyzed type (a qualified name), returns
 * {@code null} for a NULL row or NULL field, and rebuilds the dereference when the base is unresolved.
 * Field names are matched case-insensitively.
 */
@Override
protected Object visitDereferenceExpression(DereferenceExpression node, Object context)
{
    Type baseType = expressionTypes.get(node.getBase());
    if (baseType == null) {
        // if there is no type for the base of Dereference, it must be QualifiedName
        return node;
    }

    Object baseValue = process(node.getBase(), context);
    if (baseValue == null) {
        return null;
    }
    if (hasUnresolvedValue(baseValue)) {
        return new DereferenceExpression(toExpression(baseValue, baseType), node.getFieldName());
    }

    RowType rowType = (RowType) baseType;
    Block rowBlock = (Block) baseValue;
    Type fieldType = expressionTypes.get(node);

    int fieldIndex = -1;
    int current = 0;
    for (RowField field : rowType.getFields()) {
        boolean nameMatches = field.getName()
                .map(name -> name.equalsIgnoreCase(node.getFieldName()))
                .orElse(false);
        if (nameMatches) {
            checkArgument(fieldIndex < 0, "Ambiguous field %s in type %s", field, rowType.getDisplayName());
            fieldIndex = current;
        }
        current++;
    }
    checkState(fieldIndex >= 0, "could not find field name: %s", node.getFieldName());

    if (rowBlock.isNull(fieldIndex)) {
        return null;
    }

    Class<?> javaType = fieldType.getJavaType();
    if (javaType == long.class) {
        return fieldType.getLong(rowBlock, fieldIndex);
    }
    if (javaType == double.class) {
        return fieldType.getDouble(rowBlock, fieldIndex);
    }
    if (javaType == boolean.class) {
        return fieldType.getBoolean(rowBlock, fieldIndex);
    }
    if (javaType == Slice.class) {
        return fieldType.getSlice(rowBlock, fieldIndex);
    }
    if (!javaType.isPrimitive()) {
        return fieldType.getObject(rowBlock, fieldIndex);
    }
    throw new UnsupportedOperationException("Dereference a unsupported primitive type: " + javaType.getName());
}
// Identifiers are not resolved to values by this visitor; the node is returned unchanged.
@Override
protected Object visitIdentifier(Identifier node, Object context)
{
return node;
}
// Parameters are not substituted by this visitor; the node is returned unchanged.
@Override
protected Object visitParameter(Parameter node, Object context)
{
return node;
}
/**
 * Resolves a symbol through the {@link SymbolResolver} passed as context (optimizer mode).
 */
@Override
protected Object visitSymbolReference(SymbolReference node, Object context)
{
    SymbolResolver resolver = (SymbolResolver) context;
    Symbol symbol = Symbol.from(node);
    return resolver.getValue(symbol);
}
/**
 * Literals are evaluated eagerly in both interpreter and optimizer mode.
 */
@Override
protected Object visitLiteral(Literal node, Object context)
{
    Object literalValue = LiteralInterpreter.evaluate(metadata, session, node);
    return literalValue;
}
/**
 * {@code x IS NULL}: a boolean when {@code x} is resolved, otherwise a rebuilt predicate.
 */
@Override
protected Object visitIsNullPredicate(IsNullPredicate node, Object context)
{
    Object operand = process(node.getValue(), context);
    if (!(operand instanceof Expression)) {
        return operand == null;
    }
    Type operandType = expressionTypes.get(node.getValue());
    return new IsNullPredicate(toExpression(operand, operandType));
}
/**
 * {@code x IS NOT NULL}: a boolean when {@code x} is resolved, otherwise a rebuilt predicate.
 */
@Override
protected Object visitIsNotNullPredicate(IsNotNullPredicate node, Object context)
{
    Object operand = process(node.getValue(), context);
    if (!(operand instanceof Expression)) {
        return operand != null;
    }
    Type operandType = expressionTypes.get(node.getValue());
    return new IsNotNullPredicate(toExpression(operand, operandType));
}
/**
 * {@code CASE WHEN c1 THEN r1 ... ELSE d END}.
 *
 * <p>Clauses whose condition is unresolved are kept. Clauses whose condition is resolved to
 * anything but TRUE are dropped. The first TRUE clause becomes the ELSE and ends the scan.
 * Every operand is evaluated through {@code processWithExceptionHandling}, so failures are deferred.
 */
@Override
protected Object visitSearchedCaseExpression(SearchedCaseExpression node, Object context)
{
    Object elseValue = processWithExceptionHandling(node.getDefaultValue().orElse(null), context);

    List<WhenClause> residualClauses = new ArrayList<>();
    for (WhenClause clause : node.getWhenClauses()) {
        Object condition = processWithExceptionHandling(clause.getOperand(), context);
        Object branchValue = processWithExceptionHandling(clause.getResult(), context);

        if (condition instanceof Expression) {
            // cannot fully evaluate, keep the updated clause
            Expression conditionExpression = toExpression(condition, type(clause.getOperand()));
            Expression resultExpression = toExpression(branchValue, type(clause.getResult()));
            residualClauses.add(new WhenClause(conditionExpression, resultExpression));
            continue;
        }
        if (Boolean.TRUE.equals(condition)) {
            // nothing after a definitely-true clause is reachable
            elseValue = branchValue;
            break;
        }
    }

    if (residualClauses.isEmpty()) {
        return elseValue;
    }
    Expression elseExpression = null;
    if (elseValue != null) {
        elseExpression = toExpression(elseValue, type(node));
    }
    return new SearchedCaseExpression(residualClauses, Optional.ofNullable(elseExpression));
}
/**
 * {@code IF(condition, trueValue[, falseValue])}. A resolved condition selects a branch
 * (FALSE and NULL both select the false branch); an unresolved one rebuilds the expression.
 */
@Override
protected Object visitIfExpression(IfExpression node, Object context)
{
    Object thenValue = processWithExceptionHandling(node.getTrueValue(), context);
    Object elseValue = processWithExceptionHandling(node.getFalseValue().orElse(null), context);
    Object condition = processWithExceptionHandling(node.getCondition(), context);

    if (!(condition instanceof Expression)) {
        return Boolean.TRUE.equals(condition) ? thenValue : elseValue;
    }

    Expression elseExpression = null;
    if (elseValue != null) {
        elseExpression = toExpression(elseValue, type(node.getFalseValue().get()));
    }
    return new IfExpression(
            toExpression(condition, type(node.getCondition())),
            toExpression(thenValue, type(node.getTrueValue())),
            elseExpression);
}
/**
 * Processes {@code expression}, turning any RuntimeException into a failure-function expression.
 * Returns {@code null} for a {@code null} expression.
 */
private Object processWithExceptionHandling(Expression expression, Object context)
{
    if (expression != null) {
        try {
            return process(expression, context);
        }
        catch (RuntimeException e) {
            // HACK
            // Certain operations like 0 / 0 or likeExpression may throw exceptions.
            // Wrap them a FunctionCall that will throw the exception if the expression is actually executed
            return createFailureFunction(e, type(expression));
        }
    }
    return null;
}
/**
 * {@code CASE operand WHEN v1 THEN r1 ... ELSE d END}.
 *
 * <p>A NULL operand yields the ELSE value. If the operand or a WHEN value is unresolved, that clause
 * is kept. Otherwise the first WHEN value equal to the operand becomes the ELSE and ends the scan;
 * NULL or unequal WHEN values are dropped.
 */
@Override
protected Object visitSimpleCaseExpression(SimpleCaseExpression node, Object context)
{
    Object caseOperand = processWithExceptionHandling(node.getOperand(), context);
    Type caseOperandType = type(node.getOperand());

    Object elseValue = processWithExceptionHandling(node.getDefaultValue().orElse(null), context);

    if (caseOperand == null) {
        return elseValue;
    }

    boolean operandUnresolved = caseOperand instanceof Expression;
    List<WhenClause> residualClauses = new ArrayList<>();
    for (WhenClause clause : node.getWhenClauses()) {
        Object candidate = processWithExceptionHandling(clause.getOperand(), context);
        Object branchValue = processWithExceptionHandling(clause.getResult(), context);
        Type candidateType = type(clause.getOperand());

        if (candidate instanceof Expression || operandUnresolved) {
            // cannot fully evaluate, keep the updated clause
            residualClauses.add(new WhenClause(
                    toExpression(candidate, candidateType),
                    toExpression(branchValue, type(clause.getResult()))));
        }
        else if (candidate != null && isEqual(caseOperand, caseOperandType, candidate, candidateType)) {
            elseValue = branchValue;
            break;
        }
    }

    if (residualClauses.isEmpty()) {
        return elseValue;
    }
    Expression elseExpression = (elseValue == null) ? null : toExpression(elseValue, type(node));
    return new SimpleCaseExpression(toExpression(caseOperand, caseOperandType), residualClauses, Optional.ofNullable(elseExpression));
}
/**
 * Invokes the EQUAL operator for the two typed operands and unboxes the result.
 */
private boolean isEqual(Object operand1, Type type1, Object operand2, Type type2)
{
    List<Type> operandTypes = ImmutableList.of(type1, type2);
    List<Object> operandValues = ImmutableList.of(operand1, operand2);
    Object equal = invokeOperator(OperatorType.EQUAL, operandTypes, operandValues);
    return (Boolean) equal;
}
/**
 * Looks up the analyzed type of {@code expression}; {@code null} if it has none.
 */
private Type type(Expression expression)
{
    Type analyzedType = expressionTypes.get(expression);
    return analyzedType;
}
/**
 * {@code COALESCE(a, b, ...)}. NULL operands are dropped. If the first remaining operand is
 * resolved, or only one operand remains, it is the result. Otherwise a reduced COALESCE is built,
 * and an all-NULL argument list yields {@code null}.
 */
@Override
protected Object visitCoalesceExpression(CoalesceExpression node, Object context)
{
    Type resultType = type(node);

    List<Object> nonNullValues = new ArrayList<>();
    for (Expression operand : node.getOperands()) {
        Object value = processWithExceptionHandling(operand, context);
        if (value != null) {
            nonNullValues.add(value);
        }
    }

    if (nonNullValues.isEmpty()) {
        return null;
    }
    Object first = nonNullValues.get(0);
    if (!(first instanceof Expression) || nonNullValues.size() == 1) {
        return first;
    }

    List<Expression> operands = new ArrayList<>(nonNullValues.size());
    for (Object value : nonNullValues) {
        operands.add(toExpression(value, resultType));
    }
    return new CoalesceExpression(operands);
}
/**
 * {@code value IN (e1, e2, ...)} with SQL three-valued semantics.
 *
 * <p>Results:
 * <ul>
 *   <li>TRUE if some element equals the value;</li>
 *   <li>NULL if none matches but an element (or a comparison) is NULL;</li>
 *   <li>FALSE otherwise.</li>
 * </ul>
 * If anything stays unresolved, a reduced IN predicate is returned. When every element is a
 * non-null literal, a cached hash set is used instead of pairwise comparison.
 */
@Override
protected Object visitInPredicate(InPredicate node, Object context)
{
    Object value = process(node.getValue(), context);
    if (value == null) {
        return null;
    }

    Expression valueListExpression = node.getValueList();
    if (!(valueListExpression instanceof InListExpression)) {
        if (!optimize) {
            throw new UnsupportedOperationException("IN predicate value list type not yet implemented: " + valueListExpression.getClass().getName());
        }
        return node;
    }

    InListExpression valueList = (InListExpression) valueListExpression;
    Set<?> set = inListCache.get(valueList);

    // We use the presence of the node in the map to indicate that we've already done
    // the analysis below. If the value is null, it means that we can't apply the HashSet
    // optimization
    if (!inListCache.containsKey(valueList)) {
        if (valueList.getValues().stream().allMatch(Literal.class::isInstance) &&
                valueList.getValues().stream().noneMatch(NullLiteral.class::isInstance)) {
            Set<Object> objectSet = valueList.getValues().stream()
                    .map(expression -> process(expression, context))
                    .collect(Collectors.toSet());
            set = FastutilSetHelper.toFastutilHashSet(objectSet, expressionTypes.get(node.getValue()), metadata.getFunctionRegistry());
        }
        inListCache.put(valueList, set);
    }

    if (set != null && !(value instanceof Expression)) {
        return set.contains(value);
    }

    boolean hasUnresolvedValue = value instanceof Expression;
    boolean hasNullValue = false;
    boolean found = false;
    List<Object> values = new ArrayList<>(valueList.getValues().size());
    List<Type> types = new ArrayList<>(valueList.getValues().size());
    for (Expression expression : valueList.getValues()) {
        // in does not short-circuit so we must evaluate all value in the list
        Object inValue = process(expression, context);
        if (value instanceof Expression || inValue instanceof Expression) {
            hasUnresolvedValue = true;
            values.add(inValue);
            types.add(expressionTypes.get(expression));
            continue;
        }

        if (inValue == null) {
            hasNullValue = true;
        }
        else if (!found) {
            Boolean equal = (Boolean) invokeOperator(OperatorType.EQUAL, types(node.getValue(), expression), ImmutableList.of(value, inValue));
            if (equal == null) {
                // an indeterminate comparison behaves like a NULL element
                hasNullValue = true;
            }
            else if (equal) {
                found = true;
            }
        }
    }

    if (found) {
        return true;
    }

    if (hasUnresolvedValue) {
        Type type = expressionTypes.get(node.getValue());
        if (hasNullValue) {
            // Resolved NULL elements were not added to 'values' above. Keep one typed NULL so the
            // residual predicate still yields NULL (not FALSE) when nothing matches,
            // e.g. 5 IN (NULL, y) must not become 5 IN (y).
            values.add(null);
            types.add(type);
        }
        List<Expression> expressionValues = toExpressions(values, types);
        List<Expression> simplifiedExpressionValues = Stream.concat(
                expressionValues.stream()
                        .filter(DeterminismEvaluator::isDeterministic)
                        .distinct(),
                expressionValues.stream()
                        .filter(expression -> !DeterminismEvaluator.isDeterministic(expression)))
                .collect(toImmutableList());
        return new InPredicate(toExpression(value, type), new InListExpression(simplifiedExpressionValues));
    }

    if (hasNullValue) {
        return null;
    }
    return false;
}
/**
 * EXISTS subqueries are left untouched by the optimizer and are not supported by the interpreter.
 */
@Override
protected Object visitExists(ExistsPredicate node, Object context)
{
    if (optimize) {
        return node;
    }
    throw new UnsupportedOperationException("Exists subquery not yet implemented");
}
/**
 * Subqueries are left untouched by the optimizer and are not supported by the interpreter.
 */
@Override
protected Object visitSubqueryExpression(SubqueryExpression node, Object context)
{
    if (optimize) {
        return node;
    }
    throw new UnsupportedOperationException("Subquery not yet implemented");
}
/**
 * Unary {@code +x} / {@code -x}. NULL propagates and unresolved operands are rebuilt.
 * Unary plus is the identity. Unary minus invokes the NEGATION operator, binding the
 * ConnectorSession when the implementation takes it as its first parameter.
 */
@Override
protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object context)
{
    Object operand = process(node.getValue(), context);
    if (operand == null) {
        return null;
    }
    if (operand instanceof Expression) {
        return new ArithmeticUnaryExpression(node.getSign(), toExpression(operand, expressionTypes.get(node.getValue())));
    }

    switch (node.getSign()) {
        case PLUS:
            return operand;
        case MINUS: {
            FunctionRegistry registry = metadata.getFunctionRegistry();
            Signature negation = registry.resolveOperator(OperatorType.NEGATION, types(node.getValue()));
            MethodHandle negate = registry.getScalarFunctionImplementation(negation).getMethodHandle();
            boolean takesSession = negate.type().parameterCount() > 0 && negate.type().parameterType(0) == ConnectorSession.class;
            if (takesSession) {
                negate = negate.bindTo(session);
            }
            try {
                return negate.invokeWithArguments(operand);
            }
            catch (Throwable throwable) {
                Throwables.propagateIfInstanceOf(throwable, RuntimeException.class);
                Throwables.propagateIfInstanceOf(throwable, Error.class);
                throw new RuntimeException(throwable.getMessage(), throwable);
            }
        }
    }
    throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
}
/**
 * Binary arithmetic. A NULL left operand short-circuits before the right side is evaluated.
 * Unresolved operands rebuild the expression; otherwise the matching operator is invoked.
 */
@Override
protected Object visitArithmeticBinary(ArithmeticBinaryExpression node, Object context)
{
    Object left = process(node.getLeft(), context);
    if (left == null) {
        return null;
    }
    Object right = process(node.getRight(), context);
    if (right == null) {
        return null;
    }

    if (!hasUnresolvedValue(left, right)) {
        OperatorType operator = OperatorType.valueOf(node.getType().name());
        return invokeOperator(operator, types(node.getLeft(), node.getRight()), ImmutableList.of(left, right));
    }
    Expression leftExpression = toExpression(left, expressionTypes.get(node.getLeft()));
    Expression rightExpression = toExpression(right, expressionTypes.get(node.getRight()));
    return new ArithmeticBinaryExpression(node.getType(), leftExpression, rightExpression);
}
/**
 * Comparison operators.
 *
 * <p>For every operator except IS DISTINCT FROM, a NULL operand yields NULL. For IS DISTINCT FROM,
 * NULL vs NULL is FALSE, and NULL vs {@code x} is {@code x IS NOT NULL}. That is TRUE when {@code x}
 * is a resolved non-null value; when {@code x} is unresolved it is returned as an IS NOT NULL
 * predicate instead of being assumed TRUE. Unresolved operands otherwise rebuild the comparison.
 */
@Override
protected Object visitComparisonExpression(ComparisonExpression node, Object context)
{
    ComparisonExpressionType type = node.getType();

    Object left = process(node.getLeft(), context);
    if (left == null && type != ComparisonExpressionType.IS_DISTINCT_FROM) {
        return null;
    }

    Object right = process(node.getRight(), context);
    if (type == ComparisonExpressionType.IS_DISTINCT_FROM) {
        if (left == null && right == null) {
            return false;
        }
        if (left == null) {
            if (right instanceof Expression) {
                return new IsNotNullPredicate(toExpression(right, expressionTypes.get(node.getRight())));
            }
            return true;
        }
        if (right == null) {
            if (left instanceof Expression) {
                return new IsNotNullPredicate(toExpression(left, expressionTypes.get(node.getLeft())));
            }
            return true;
        }
    }
    else if (right == null) {
        return null;
    }

    if (hasUnresolvedValue(left, right)) {
        return new ComparisonExpression(type, toExpression(left, expressionTypes.get(node.getLeft())), toExpression(right, expressionTypes.get(node.getRight())));
    }
    return invokeOperator(OperatorType.valueOf(type.name()), types(node.getLeft(), node.getRight()), ImmutableList.of(left, right));
}
/**
 * {@code value BETWEEN min AND max}. Operands are evaluated in order and the first NULL stops
 * evaluation and yields NULL. Unresolved operands rebuild the predicate; otherwise the
 * BETWEEN operator is invoked.
 */
@Override
protected Object visitBetweenPredicate(BetweenPredicate node, Object context)
{
    Object value = process(node.getValue(), context);
    if (value == null) {
        return null;
    }
    Object min = process(node.getMin(), context);
    if (min == null) {
        return null;
    }
    Object max = process(node.getMax(), context);
    if (max == null) {
        return null;
    }

    if (!hasUnresolvedValue(value, min, max)) {
        return invokeOperator(OperatorType.BETWEEN, types(node.getValue(), node.getMin(), node.getMax()), ImmutableList.of(value, min, max));
    }
    Expression valueExpression = toExpression(value, expressionTypes.get(node.getValue()));
    Expression minExpression = toExpression(min, expressionTypes.get(node.getMin()));
    Expression maxExpression = toExpression(max, expressionTypes.get(node.getMax()));
    return new BetweenPredicate(valueExpression, minExpression, maxExpression);
}
/**
 * {@code NULLIF(first, second)}: NULL when the two values are equal after coercion to their common
 * super type, otherwise {@code first}. A NULL {@code first} yields NULL and a NULL {@code second}
 * yields {@code first}.
 */
@Override
protected Object visitNullIfExpression(NullIfExpression node, Object context)
{
    Object first = process(node.getFirst(), context);
    if (first == null) {
        return null;
    }
    Object second = process(node.getSecond(), context);
    if (second == null) {
        return first;
    }

    Type firstType = expressionTypes.get(node.getFirst());
    Type secondType = expressionTypes.get(node.getSecond());
    if (hasUnresolvedValue(first, second)) {
        return new NullIfExpression(toExpression(first, firstType), toExpression(second, secondType));
    }

    Type commonType = metadata.getTypeManager().getCommonSuperType(firstType, secondType).get();
    FunctionRegistry registry = metadata.getFunctionRegistry();
    Signature firstToCommon = registry.getCoercion(firstType, commonType);
    Signature secondToCommon = registry.getCoercion(secondType, commonType);
    ScalarFunctionImplementation firstCoercion = registry.getScalarFunctionImplementation(firstToCommon);
    ScalarFunctionImplementation secondCoercion = registry.getScalarFunctionImplementation(secondToCommon);

    // cast(first as <common type>) == cast(second as <common type>)
    List<Type> comparisonTypes = ImmutableList.of(commonType, commonType);
    Object firstCoerced = invoke(session, firstCoercion, ImmutableList.of(first));
    Object secondCoerced = invoke(session, secondCoercion, ImmutableList.of(second));
    boolean equal = (Boolean) invokeOperator(OperatorType.EQUAL, comparisonTypes, ImmutableList.of(firstCoerced, secondCoerced));

    return equal ? null : first;
}
/**
 * Logical {@code NOT}: NULL propagates, unresolved operands are rebuilt, booleans are negated.
 */
@Override
protected Object visitNotExpression(NotExpression node, Object context)
{
    Object operand = process(node.getValue(), context);
    if (operand instanceof Expression) {
        return new NotExpression(toExpression(operand, expressionTypes.get(node.getValue())));
    }
    if (operand == null) {
        return null;
    }
    return !(Boolean) operand;
}
/**
 * {@code AND} / {@code OR} with three-valued logic. Both sides are always evaluated.
 *
 * <p>For AND, FALSE is the dominant value and TRUE the neutral one; OR is the reverse. A dominant
 * operand decides the result on its own, and a neutral operand means the result is the other
 * operand. If both sides are NULL the result is NULL; anything else is rebuilt as an expression.
 */
@Override
protected Object visitLogicalBinaryExpression(LogicalBinaryExpression node, Object context)
{
    Object left = process(node.getLeft(), context);
    Object right = process(node.getRight(), context);

    Boolean dominant = null;
    switch (node.getType()) {
        case AND:
            dominant = Boolean.FALSE;
            break;
        case OR:
            dominant = Boolean.TRUE;
            break;
    }

    if (dominant != null) {
        Boolean neutral = !dominant;
        if (dominant.equals(left) || neutral.equals(right)) {
            return left;
        }
        if (dominant.equals(right) || neutral.equals(left)) {
            return right;
        }
    }

    if (left == null && right == null) {
        return null;
    }
    Expression leftExpression = toExpression(left, expressionTypes.get(node.getLeft()));
    Expression rightExpression = toExpression(right, expressionTypes.get(node.getRight()));
    return new LogicalBinaryExpression(node.getType(), leftExpression, rightExpression);
}
@Override
protected Object visitBooleanLiteral(BooleanLiteral node, Object context)
{
    // A boolean literal evaluates to a Java Boolean: true iff it equals the TRUE literal.
    boolean isTrue = BooleanLiteral.TRUE_LITERAL.equals(node);
    return isTrue;
}
@Override
protected Object visitFunctionCall(FunctionCall node, Object context)
{
    // Evaluate every argument and record its analyzed type, in argument order.
    List<Object> argumentValues = new ArrayList<>();
    List<Type> argumentTypes = new ArrayList<>();
    for (Expression argument : node.getArguments()) {
        argumentValues.add(process(argument, context));
        argumentTypes.add(expressionTypes.get(argument));
    }

    FunctionRegistry registry = metadata.getFunctionRegistry();
    Signature signature = registry.resolveFunction(node.getName(), fromTypes(argumentTypes));
    ScalarFunctionImplementation function = registry.getScalarFunctionImplementation(signature);

    // A NULL passed where the function does not accept a nullable argument makes the call NULL.
    for (int position = 0; position < argumentValues.size(); position++) {
        if (argumentValues.get(position) == null && !function.getNullableArguments().get(position)) {
            return null;
        }
    }

    // When optimizing, keep non-deterministic calls and calls with unresolved arguments symbolic.
    boolean keepSymbolic = optimize && (!function.isDeterministic() || hasUnresolvedValue(argumentValues));
    if (keepSymbolic) {
        return new FunctionCall(node.getName(), node.getWindow(), node.isDistinct(), toExpressions(argumentValues, argumentTypes));
    }
    return invoke(session, function, argumentValues);
}
@Override
protected Object visitLambdaExpression(LambdaExpression node, Object context)
{
    if (optimize) {
        // TODO: enable optimization related to lambda expression
        // A mechanism to convert function type back into lambda expression need to exist to enable optimization
        return node;
    }

    FunctionType functionType = (FunctionType) expressionTypes.get(node);

    // The generated adapter takes the ConnectorSession first, then one (boxed) value per lambda argument.
    ImmutableList.Builder<Class<?>> parameterClasses = ImmutableList.builder();
    parameterClasses.add(ConnectorSession.class);
    for (Type lambdaArgumentType : functionType.getArgumentTypes()) {
        parameterClasses.add(Primitives.wrap(lambdaArgumentType.getJavaType()));
    }

    ImmutableList.Builder<String> parameterNames = ImmutableList.builder();
    parameterNames.add("$connector_session");
    for (LambdaArgumentDeclaration declaration : node.getArguments()) {
        parameterNames.add(declaration.getName());
    }

    List<Class<?>> argumentTypes = parameterClasses.build();
    List<String> argumentNames = parameterNames.build();
    checkArgument(argumentTypes.size() == argumentNames.size());

    // Each invocation interprets the lambda body with its arguments bound by name.
    Expression body = node.getBody();
    return generateVarArgsToMapAdapter(
            Primitives.wrap(functionType.getReturnType().getJavaType()),
            argumentTypes,
            argumentNames,
            map -> process(body, new LambdaSymbolResolver(map)));
}
@Override
protected Object visitBindExpression(BindExpression node, Object context)
{
    // The bound value may legitimately evaluate to SQL NULL (Java null).
    Object value = process(node.getValue(), context);
    Object function = process(node.getFunction(), context);

    // Check each operand directly instead of using the varargs hasUnresolvedValue helper: that helper
    // copies its arguments into an ImmutableList, which rejects nulls and would throw a
    // NullPointerException when the bound value is NULL.
    if (value instanceof Expression || function instanceof Expression) {
        return new BindExpression(
                toExpression(value, expressionTypes.get(node.getValue())),
                toExpression(function, expressionTypes.get(node.getFunction())));
    }
    // Lambdas produced by visitLambdaExpression take the ConnectorSession as parameter 0, so the
    // bound value is inserted at index 1.
    return MethodHandles.insertArguments((MethodHandle) function, 1, value);
}
@Override
protected Object visitLikePredicate(LikePredicate node, Object context)
{
// Evaluates or partially evaluates `value LIKE pattern [ESCAPE escape]`.
// Returns a Boolean when every operand is a constant, null when a required operand is NULL,
// and otherwise a residual expression (an EQUAL comparison or a rebuilt LikePredicate).
Object value = process(node.getValue(), context);
if (value == null) {
return null;
}
if (value instanceof Slice &&
node.getPattern() instanceof StringLiteral &&
(node.getEscape() instanceof StringLiteral || node.getEscape() == null)) {
// fast path when we know the pattern and escape are constant
// (the compiled Regex is memoized per node by getConstantPattern)
return LikeFunctions.like((Slice) value, getConstantPattern(node));
}
Object pattern = process(node.getPattern(), context);
if (pattern == null) {
return null;
}
Object escape = null;
if (node.getEscape() != null) {
escape = process(node.getEscape(), context);
// An explicit ESCAPE that evaluates to NULL makes the whole predicate NULL.
if (escape == null) {
return null;
}
}
// Every operand resolved to a constant: compile the pattern and evaluate directly.
if (value instanceof Slice &&
pattern instanceof Slice &&
(escape == null || escape instanceof Slice)) {
Regex regex;
if (escape == null) {
regex = LikeFunctions.likePattern((Slice) pattern);
}
else {
regex = LikeFunctions.likePattern((Slice) pattern, (Slice) escape);
}
return LikeFunctions.like((Slice) value, regex);
}
// if pattern is a constant without % or _ replace with a comparison
// (at this point value is not a constant Slice; otherwise the branch above would have returned)
if (pattern instanceof Slice && (escape == null || escape instanceof Slice) && !isLikePattern((Slice) pattern, (Slice) escape)) {
Slice unescapedPattern = unescapeLiteralLikePattern((Slice) pattern, (Slice) escape);
Type valueType = expressionTypes.get(node.getValue());
Type patternType = createVarcharType(unescapedPattern.length());
TypeManager typeManager = metadata.getTypeManager();
// Both sides are coerced to their common super type so the EQUAL comparison is well-typed.
Optional<Type> commonSuperType = typeManager.getCommonSuperType(valueType, patternType);
checkArgument(commonSuperType.isPresent(), "Missing super type when optimizing %s", node);
Expression valueExpression = toExpression(value, valueType);
Expression patternExpression = toExpression(unescapedPattern, patternType);
Type superType = commonSuperType.get();
// Only wrap a side in a CAST when its type differs from the super type.
if (!valueType.equals(superType)) {
valueExpression = new Cast(valueExpression, superType.getTypeSignature().toString(), false, typeManager.isTypeOnlyCoercion(valueType, superType));
}
if (!patternType.equals(superType)) {
patternExpression = new Cast(patternExpression, superType.getTypeSignature().toString(), false, typeManager.isTypeOnlyCoercion(patternType, superType));
}
return new ComparisonExpression(ComparisonExpressionType.EQUAL, valueExpression, patternExpression);
}
// General case: rebuild the LIKE, converting each processed operand back to an expression via toExpression.
Expression optimizedEscape = null;
if (node.getEscape() != null) {
optimizedEscape = toExpression(escape, expressionTypes.get(node.getEscape()));
}
return new LikePredicate(
toExpression(value, expressionTypes.get(node.getValue())),
toExpression(pattern, expressionTypes.get(node.getPattern())),
optimizedEscape);
}
private Regex getConstantPattern(LikePredicate node)
{
    // Compiles the constant pattern (and optional constant escape) of this LIKE node once,
    // memoizing the result in likePatternCache keyed by the node.
    Regex cached = likePatternCache.get(node);
    if (cached != null) {
        return cached;
    }

    Slice patternSlice = ((StringLiteral) node.getPattern()).getSlice();
    StringLiteral escapeLiteral = (StringLiteral) node.getEscape();
    Regex compiled = (escapeLiteral == null)
            ? LikeFunctions.likePattern(patternSlice)
            : LikeFunctions.likePattern(patternSlice, escapeLiteral.getSlice());

    likePatternCache.put(node, compiled);
    return compiled;
}
@Override
protected Object visitTryExpression(TryExpression node, Object context)
{
    Object inner;
    try {
        inner = process(node.getInnerExpression(), context);
    }
    catch (PrestoException e) {
        // The failure is handed to tryExpressionExceptionHandler; if it returns normally, TRY yields NULL.
        tryExpressionExceptionHandler(e);
        return null;
    }
    // An unresolved inner expression stays wrapped in TRY; a resolved value is returned as-is.
    return (inner instanceof Expression) ? new TryExpression((Expression) inner) : inner;
}
@Override
public Object visitCast(Cast node, Object context)
{
    Object value = process(node.getExpression(), context);
    Type sourceType = expressionTypes.get(node.getExpression());

    // Unresolved operand: rebuild the CAST around the residual expression.
    if (value instanceof Expression) {
        return new Cast((Expression) value, node.getType(), node.isSafe(), node.isTypeOnly());
    }
    // Type-only casts leave the value untouched.
    if (node.isTypeOnly()) {
        return value;
    }

    // hack!!! don't optimize CASTs for types that cannot be represented in the SQL AST
    // TODO: this will not be an issue when we migrate to RowExpression tree for this, which allows arbitrary literals.
    if (optimize && !FunctionRegistry.isSupportedLiteralType(expressionTypes.get(node))) {
        return new Cast(toExpression(value, sourceType), node.getType(), node.isSafe(), node.isTypeOnly());
    }

    if (value == null) {
        return null;
    }

    Type targetType = metadata.getType(parseTypeSignature(node.getType()));
    if (targetType == null) {
        throw new IllegalArgumentException("Unsupported type: " + node.getType());
    }

    FunctionRegistry registry = metadata.getFunctionRegistry();
    Signature coercion = registry.getCoercion(sourceType, targetType);
    try {
        return invoke(session, registry.getScalarFunctionImplementation(coercion), ImmutableList.of(value));
    }
    catch (RuntimeException e) {
        // TRY_CAST-style (safe) casts turn a failed conversion into NULL.
        if (!node.isSafe()) {
            throw e;
        }
        return null;
    }
}
@Override
protected Object visitArrayConstructor(ArrayConstructor node, Object context)
{
    List<Expression> elements = node.getValues();
    Type elementType = ((ArrayType) expressionTypes.get(node)).getElementType();
    BlockBuilder elementsBuilder = elementType.createBlockBuilder(new BlockBuilderStatus(), elements.size());

    for (Expression element : elements) {
        Object elementValue = process(element, context);
        if (elementValue instanceof Expression) {
            // At least one element is unresolved: treat the constructor as a call to the array
            // constructor function instead (which processes all elements again).
            FunctionCall constructorCall = new FunctionCall(QualifiedName.of(ArrayConstructor.ARRAY_CONSTRUCTOR), elements);
            return visitFunctionCall(constructorCall, context);
        }
        writeNativeValue(elementType, elementsBuilder, elementValue);
    }
    return elementsBuilder.build();
}
@Override
protected Object visitRow(Row node, Object context)
{
    List<Type> fieldTypes = ((RowType) expressionTypes.get(node)).getTypeParameters();
    List<Expression> fields = node.getItems();
    int fieldCount = fields.size();

    // Field values may be null (SQL NULL), constants, or unresolved expressions.
    List<Object> fieldValues = new ArrayList<>(fieldCount);
    for (Expression field : fields) {
        fieldValues.add(process(field, context));
    }

    if (hasUnresolvedValue(fieldValues)) {
        return new Row(toExpressions(fieldValues, fieldTypes));
    }

    // Every field is a constant: materialize the row as a single-row interleaved block.
    BlockBuilder rowBuilder = new InterleavedBlockBuilder(fieldTypes, new BlockBuilderStatus(), fieldCount);
    for (int field = 0; field < fieldCount; field++) {
        writeNativeValue(fieldTypes.get(field), rowBuilder, fieldValues.get(field));
    }
    return rowBuilder.build();
}
@Override
protected Object visitSubscriptExpression(SubscriptExpression node, Object context)
{
    Expression baseExpression = node.getBase();
    Expression indexExpression = node.getIndex();

    // A NULL base or NULL index makes the subscript NULL.
    Object base = process(baseExpression, context);
    if (base == null) {
        return null;
    }
    Object index = process(indexExpression, context);
    if (index == null) {
        return null;
    }

    // Validate a constant array index eagerly, even when the base is still unresolved.
    if (index instanceof Long && isArray(expressionTypes.get(baseExpression))) {
        ArraySubscriptOperator.checkArrayIndex((Long) index);
    }

    if (hasUnresolvedValue(base, index)) {
        Expression residualBase = toExpression(base, expressionTypes.get(baseExpression));
        Expression residualIndex = toExpression(index, expressionTypes.get(indexExpression));
        return new SubscriptExpression(residualBase, residualIndex);
    }
    return invokeOperator(OperatorType.SUBSCRIPT, types(baseExpression, indexExpression), ImmutableList.of(base, index));
}
@Override
protected Object visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, Object context)
{
    // Quantified comparisons are left untouched while optimizing; direct evaluation is unsupported.
    if (optimize) {
        return node;
    }
    throw new UnsupportedOperationException("QuantifiedComparison not yet implemented");
}
@Override
protected Object visitExpression(Expression node, Object context)
{
    // Fallback for expression kinds this interpreter has no dedicated visitor for.
    String expressionClass = node.getClass().getName();
    throw new PrestoException(NOT_SUPPORTED, "not yet implemented: " + expressionClass);
}
@Override
protected Object visitNode(Node node, Object context)
{
    // Only Expression nodes are meaningful to this visitor; any other AST node is a caller error.
    String message = "Evaluator visitor can only handle Expression nodes";
    throw new UnsupportedOperationException(message);
}
private List<Type> types(Expression... types)
{
    // Looks up the analyzed type of each expression, in order. Functions.forMap fails fast
    // (IllegalArgumentException) when an expression has no recorded type.
    return ImmutableList.copyOf(types).stream()
            .map(Functions.forMap(expressionTypes)::apply)
            .collect(toImmutableList());
}
/**
 * Returns true if any of the given processed values is still an unresolved {@link Expression}.
 * Values may be null (SQL NULL); nulls count as resolved. The array is scanned directly because
 * copying it into an ImmutableList (as before) throws NullPointerException on null elements.
 */
private boolean hasUnresolvedValue(Object... values)
{
    for (Object value : values) {
        if (value instanceof Expression) {
            return true;
        }
    }
    return false;
}
private boolean hasUnresolvedValue(List<Object> values)
{
    // True if any processed value is still an Expression; null entries (SQL NULL) never match,
    // because Class.isInstance(null) is false.
    return values.stream().anyMatch(Expression.class::isInstance);
}
private Object invokeOperator(OperatorType operatorType, List<? extends Type> argumentTypes, List<Object> argumentValues)
{
    // Resolve the operator for the given argument types, then run its scalar implementation.
    FunctionRegistry registry = metadata.getFunctionRegistry();
    Signature signature = registry.resolveOperator(operatorType, argumentTypes);
    ScalarFunctionImplementation implementation = registry.getScalarFunctionImplementation(signature);
    return invoke(session, implementation, argumentValues);
}
}
/**
 * Maps a channel to the block holding it and to the row position to read within that block.
 */
private interface PagePositionContext
{
    Block getBlock(int channel);

    int getPosition(int channel);
}
/**
 * Position context over a single page: channel N is blocks[N], and every channel is read at the
 * same position.
 */
private static class SinglePagePositionContext
        implements PagePositionContext
{
    private final Block[] blocks;
    private final int position;

    private SinglePagePositionContext(int position, Block[] blocks)
    {
        this.blocks = blocks;
        this.position = position;
    }

    @Override
    public Block getBlock(int channel)
    {
        return blocks[channel];
    }

    @Override
    public int getPosition(int channel)
    {
        // The channel is irrelevant: all channels share one position.
        return position;
    }
}
/**
 * Position context over two pages laid side by side. Channels [0, leftBlocks.length) come from the
 * left page at leftPosition; higher channels come from the right page at rightPosition.
 */
private static class TwoPagesPositionContext
        implements PagePositionContext
{
    private final Block[] leftBlocks;
    private final int leftPosition;
    private final Block[] rightBlocks;
    private final int rightPosition;

    private TwoPagesPositionContext(int leftPosition, Block[] leftBlocks, int rightPosition, Block[] rightBlocks)
    {
        this.leftBlocks = leftBlocks;
        this.leftPosition = leftPosition;
        this.rightBlocks = rightBlocks;
        this.rightPosition = rightPosition;
    }

    @Override
    public Block getBlock(int channel)
    {
        return isLeftChannel(channel) ? leftBlocks[channel] : rightBlocks[channel - leftBlocks.length];
    }

    @Override
    public int getPosition(int channel)
    {
        return isLeftChannel(channel) ? leftPosition : rightPosition;
    }

    // Channels below the left page's width belong to the left page.
    private boolean isLeftChannel(int channel)
    {
        return channel < leftBlocks.length;
    }
}
public static Object invoke(ConnectorSession session, ScalarFunctionImplementation function, List<Object> argumentValues)
{
    MethodHandle handle = function.getMethodHandle();

    // Instance-based implementations: bind an instance produced by the instance factory as the receiver.
    if (function.getInstanceFactory().isPresent()) {
        try {
            handle = handle.bindTo(function.getInstanceFactory().get().invoke());
        }
        catch (Throwable throwable) {
            throw propagateInvocationFailure(throwable);
        }
    }

    // Implementations whose first parameter is a ConnectorSession get the session bound in.
    boolean takesSession = handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class;
    if (takesSession) {
        handle = handle.bindTo(session);
    }

    try {
        Class<?>[] parameterTypes = handle.type().parameterArray();
        List<Object> actualArguments = new ArrayList<>();
        int argumentIndex = 0;
        for (Object argument : argumentValues) {
            boolean usesNullFlag = function.getNullFlags().get(argumentIndex++);
            if (!usesNullFlag) {
                actualArguments.add(argument);
                continue;
            }
            // Null-flag convention: pass the value (or the default of the declared parameter type
            // when NULL), followed by a boolean telling whether it is NULL.
            boolean isNull = argument == null;
            Object passed = argument;
            if (isNull) {
                passed = Defaults.defaultValue(parameterTypes[actualArguments.size()]);
            }
            actualArguments.add(passed);
            actualArguments.add(isNull);
        }
        return handle.invokeWithArguments(actualArguments);
    }
    catch (Throwable throwable) {
        throw propagateInvocationFailure(throwable);
    }
}

/**
 * Restores the thread's interrupt flag for InterruptedException, then rethrows the failure
 * unchecked via Throwables.propagate. Never returns normally; the return type lets call sites
 * write {@code throw propagateInvocationFailure(t)}.
 */
private static RuntimeException propagateInvocationFailure(Throwable throwable)
{
    if (throwable instanceof InterruptedException) {
        Thread.currentThread().interrupt();
    }
    throw Throwables.propagate(throwable);
}
@VisibleForTesting
public static Expression createFailureFunction(RuntimeException exception, Type type)
{
    requireNonNull(exception, "Exception is null");

    // Builds CAST(fail(json_parse('<serialized FailureInfo>')) AS <type>), embedding the exception
    // as a JSON string literal.
    JsonCodec<FailureInfo> codec = JsonCodec.jsonCodec(FailureInfo.class);
    String failureJson = codec.toJson(Failures.toFailure(exception).toFailureInfo());

    Expression parsedFailure = new FunctionCall(QualifiedName.of("json_parse"), ImmutableList.of(new StringLiteral(failureJson)));
    Expression failCall = new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(parsedFailure));
    return new Cast(failCall, type.getTypeSignature().toString());
}
private static boolean isArray(Type type)
{
    // A type is an array when the base name of its signature is the ARRAY standard type.
    String baseName = type.getTypeSignature().getBase();
    return baseName.equals(StandardTypes.ARRAY);
}
/**
 * Resolves symbols inside a lambda body by name against the argument values of one invocation.
 * A symbol that is missing from the map fails a checkState; null values are allowed.
 */
private static class LambdaSymbolResolver
        implements SymbolResolver
{
    private final Map<String, Object> values;

    public LambdaSymbolResolver(Map<String, Object> values)
    {
        this.values = requireNonNull(values, "values is null");
    }

    @Override
    public Object getValue(Symbol symbol)
    {
        String name = symbol.getName();
        // containsKey (not a null check) because a bound argument may be NULL.
        checkState(values.containsKey(name), "values does not contain %s", symbol);
        return values.get(name);
    }
}
}