/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.relational;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.DecimalParseResult;
import com.facebook.presto.spi.type.Decimals;
import com.facebook.presto.spi.type.TimeZoneKey;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.relational.optimizer.ExpressionOptimizer;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate;
import com.facebook.presto.sql.tree.BinaryLiteral;
import com.facebook.presto.sql.tree.BindExpression;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CharLiteral;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.DoubleLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FieldReference;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.IfExpression;
import com.facebook.presto.sql.tree.InListExpression;
import com.facebook.presto.sql.tree.InPredicate;
import com.facebook.presto.sql.tree.IntervalLiteral;
import com.facebook.presto.sql.tree.IsNotNullPredicate;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.Row;
import com.facebook.presto.sql.tree.SearchedCaseExpression;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.sql.tree.TimeLiteral;
import com.facebook.presto.sql.tree.TimestampLiteral;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.sql.tree.WhenClause;
import com.facebook.presto.type.RowType;
import com.facebook.presto.type.RowType.RowField;
import com.facebook.presto.type.UnknownType;
import com.facebook.presto.util.maps.IdentityLinkedHashMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
import static com.facebook.presto.metadata.FunctionKind.SCALAR;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.type.CharType.createCharType;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.type.IntegerType.INTEGER;
import static com.facebook.presto.spi.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.spi.type.VarcharType.createVarcharType;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.facebook.presto.sql.relational.Expressions.field;
import static com.facebook.presto.sql.relational.Signatures.arithmeticExpressionSignature;
import static com.facebook.presto.sql.relational.Signatures.arithmeticNegationSignature;
import static com.facebook.presto.sql.relational.Signatures.arrayConstructorSignature;
import static com.facebook.presto.sql.relational.Signatures.betweenSignature;
import static com.facebook.presto.sql.relational.Signatures.bindSignature;
import static com.facebook.presto.sql.relational.Signatures.castSignature;
import static com.facebook.presto.sql.relational.Signatures.coalesceSignature;
import static com.facebook.presto.sql.relational.Signatures.comparisonExpressionSignature;
import static com.facebook.presto.sql.relational.Signatures.dereferenceSignature;
import static com.facebook.presto.sql.relational.Signatures.likePatternSignature;
import static com.facebook.presto.sql.relational.Signatures.likeSignature;
import static com.facebook.presto.sql.relational.Signatures.logicalExpressionSignature;
import static com.facebook.presto.sql.relational.Signatures.nullIfSignature;
import static com.facebook.presto.sql.relational.Signatures.rowConstructorSignature;
import static com.facebook.presto.sql.relational.Signatures.subscriptSignature;
import static com.facebook.presto.sql.relational.Signatures.switchSignature;
import static com.facebook.presto.sql.relational.Signatures.tryCastSignature;
import static com.facebook.presto.sql.relational.Signatures.whenSignature;
import static com.facebook.presto.type.JsonType.JSON;
import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN;
import static com.facebook.presto.util.DateTimeUtils.parseDayTimeInterval;
import static com.facebook.presto.util.DateTimeUtils.parseTimeWithTimeZone;
import static com.facebook.presto.util.DateTimeUtils.parseTimeWithoutTimeZone;
import static com.facebook.presto.util.DateTimeUtils.parseTimestampWithTimeZone;
import static com.facebook.presto.util.DateTimeUtils.parseTimestampWithoutTimeZone;
import static com.facebook.presto.util.DateTimeUtils.parseYearMonthInterval;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.airlift.slice.Slices.utf8Slice;
import static java.util.Objects.requireNonNull;
public final class SqlToRowExpressionTranslator
{
private SqlToRowExpressionTranslator()
{
    // Only static entry points are exposed; instances are never created.
}
/**
 * Translates a SQL AST {@link Expression} into a {@link RowExpression}.
 *
 * @param expression the AST expression to translate
 * @param functionKind the kind recorded in signatures built for function calls
 * @param types the type of each expression node, looked up while translating
 * @param functionRegistry passed to the {@link ExpressionOptimizer} when {@code optimize} is set
 * @param typeManager used to resolve type signatures and passed to the optimizer
 * @param session supplies the time zone key for literal parsing and is passed to the optimizer
 * @param optimize when {@code true}, the translated expression is run through {@link ExpressionOptimizer}
 * @return the translated (and optionally optimized) row expression
 * @throws NullPointerException if the visitor produces a {@code null} translation
 */
public static RowExpression translate(
        Expression expression,
        FunctionKind functionKind,
        IdentityLinkedHashMap<Expression, Type> types,
        FunctionRegistry functionRegistry,
        TypeManager typeManager,
        Session session,
        boolean optimize)
{
    Visitor visitor = new Visitor(functionKind, types, typeManager, session.getTimeZoneKey());
    RowExpression translated = requireNonNull(visitor.process(expression, null), "translated expression is null");

    if (!optimize) {
        return translated;
    }
    return new ExpressionOptimizer(functionRegistry, typeManager, session).optimize(translated);
}
private static class Visitor
extends AstVisitor<RowExpression, Void>
{
// Kind recorded in the Signature built for each FunctionCall node.
private final FunctionKind functionKind;
// Type of each expression node; consulted by most visit methods.
private final IdentityLinkedHashMap<Expression, Type> types;
// Resolves the declared type of generic literals.
private final TypeManager typeManager;
// Time zone key handed to the date/time literal parsers.
private final TimeZoneKey timeZoneKey;

/**
 * Creates a visitor; the arguments are stored as-is without validation.
 */
private Visitor(FunctionKind functionKind, IdentityLinkedHashMap<Expression, Type> types, TypeManager typeManager, TimeZoneKey timeZoneKey)
{
    this.timeZoneKey = timeZoneKey;
    this.typeManager = typeManager;
    this.types = types;
    this.functionKind = functionKind;
}
/**
 * Fallback for AST node types without a dedicated translation.
 *
 * @throws UnsupportedOperationException always, naming the unsupported node class
 */
@Override
protected RowExpression visitExpression(Expression node, Void context)
{
    String nodeClassName = node.getClass().getName();
    throw new UnsupportedOperationException("not yet implemented: expression translator for " + nodeClassName);
}
/**
 * Translates a field reference into an input field expression carrying the node's recorded type.
 */
@Override
protected RowExpression visitFieldReference(FieldReference node, Void context)
{
    Type fieldType = types.get(node);
    return field(node.getFieldIndex(), fieldType);
}
/**
 * Translates a {@code NULL} literal into a null constant of {@link UnknownType#UNKNOWN}.
 */
@Override
protected RowExpression visitNullLiteral(NullLiteral node, Void context)
{
    Type unknownType = UnknownType.UNKNOWN;
    return constantNull(unknownType);
}
/**
 * Translates a boolean literal into a {@code BOOLEAN} constant.
 */
@Override
protected RowExpression visitBooleanLiteral(BooleanLiteral node, Void context)
{
    Object literalValue = node.getValue();
    return constant(literalValue, BOOLEAN);
}
/**
 * Translates an integral literal into a constant: {@code INTEGER} when the value lies
 * within the {@code int} range (inclusive), {@code BIGINT} otherwise.
 */
@Override
protected RowExpression visitLongLiteral(LongLiteral node, Void context)
{
    long literalValue = node.getValue();
    boolean fitsInInteger = literalValue >= Integer.MIN_VALUE && literalValue <= Integer.MAX_VALUE;
    if (!fitsInInteger) {
        return constant(literalValue, BIGINT);
    }
    return constant(literalValue, INTEGER);
}
/**
 * Translates a floating point literal into a {@code DOUBLE} constant.
 */
@Override
protected RowExpression visitDoubleLiteral(DoubleLiteral node, Void context)
{
    Object literalValue = node.getValue();
    return constant(literalValue, DOUBLE);
}
/**
 * Translates a decimal literal by parsing its text with {@link Decimals#parse};
 * both the constant's value and its type come from the parse result.
 */
@Override
protected RowExpression visitDecimalLiteral(DecimalLiteral node, Void context)
{
    DecimalParseResult parsed = Decimals.parse(node.getValue());
    Type decimalType = parsed.getType();
    return constant(parsed.getObject(), decimalType);
}
/**
 * Translates a string literal into a constant whose type is a varchar
 * bounded by the number of code points in the literal.
 */
@Override
protected RowExpression visitStringLiteral(StringLiteral node, Void context)
{
    Type boundedVarchar = createVarcharType(countCodePoints(node.getSlice()));
    return constant(node.getSlice(), boundedVarchar);
}
/**
 * Translates a char literal into a constant whose char type length is
 * the {@code String.length()} of the literal's value.
 */
@Override
protected RowExpression visitCharLiteral(CharLiteral node, Void context)
{
    // NOTE(review): String.length() counts UTF-16 code units, whereas visitStringLiteral
    // counts code points; confirm this difference is intended.
    Type charType = createCharType(node.getValue().length());
    return constant(node.getSlice(), charType);
}
/**
 * Translates a binary literal into a {@code VARBINARY} constant.
 */
@Override
protected RowExpression visitBinaryLiteral(BinaryLiteral node, Void context)
{
    Object literalValue = node.getValue();
    return constant(literalValue, VARBINARY);
}
/**
 * Translates a generic literal ({@code type 'text'}). The literal text becomes a
 * {@code VARCHAR} constant which is passed to {@code json_parse} when the declared
 * type is JSON, or to a cast from {@code VARCHAR} for every other type.
 *
 * @throws IllegalArgumentException if the type manager does not know the declared type
 */
@Override
protected RowExpression visitGenericLiteral(GenericLiteral node, Void context)
{
    Type declaredType = typeManager.getType(parseTypeSignature(node.getType()));
    if (declaredType == null) {
        throw new IllegalArgumentException("Unsupported type: " + node.getType());
    }

    Type resultType = types.get(node);
    Signature signature;
    if (JSON.equals(declaredType)) {
        signature = new Signature("json_parse", SCALAR, resultType.getTypeSignature(), VARCHAR.getTypeSignature());
    }
    else {
        signature = castSignature(resultType, VARCHAR);
    }

    RowExpression literalText = constant(utf8Slice(node.getValue()), VARCHAR);
    return call(signature, resultType, literalText);
}
/**
 * Translates a time literal into a constant of the node's recorded type. A literal typed
 * {@code TIME WITH TIME ZONE} is parsed with its own zone; any other is parsed using
 * the visitor's time zone key.
 */
@Override
protected RowExpression visitTimeLiteral(TimeLiteral node, Void context)
{
    Type literalType = types.get(node);
    long parsed = literalType.equals(TIME_WITH_TIME_ZONE)
            ? parseTimeWithTimeZone(node.getValue())
            // parse in time zone of client
            : parseTimeWithoutTimeZone(timeZoneKey, node.getValue());
    return constant(parsed, literalType);
}
/**
 * Translates a timestamp literal into a constant of the node's recorded type. The visitor's
 * time zone key is passed to the parser in both cases; only the parser differs depending on
 * whether the type is {@code TIMESTAMP WITH TIME ZONE}.
 */
@Override
protected RowExpression visitTimestampLiteral(TimestampLiteral node, Void context)
{
    Type literalType = types.get(node);
    long parsed = literalType.equals(TIMESTAMP_WITH_TIME_ZONE)
            ? parseTimestampWithTimeZone(timeZoneKey, node.getValue())
            // parse in time zone of client
            : parseTimestampWithoutTimeZone(timeZoneKey, node.getValue());
    return constant(parsed, literalType);
}
/**
 * Translates an interval literal into a constant of the node's recorded type. The text is
 * parsed as a year-month or day-time interval and multiplied by the literal's sign multiplier.
 */
@Override
protected RowExpression visitIntervalLiteral(IntervalLiteral node, Void context)
{
    long signMultiplier = node.getSign().multiplier();
    long magnitude = node.isYearToMonth()
            ? parseYearMonthInterval(node.getValue(), node.getStartField(), node.getEndField())
            : parseDayTimeInterval(node.getValue(), node.getStartField(), node.getEndField());
    long signedValue = signMultiplier * magnitude;
    return constant(signedValue, types.get(node));
}
/**
 * Translates a comparison into a {@code BOOLEAN} call whose signature is derived from
 * the comparison type and the types of the translated operands.
 */
@Override
protected RowExpression visitComparisonExpression(ComparisonExpression node, Void context)
{
    RowExpression leftOperand = process(node.getLeft(), context);
    RowExpression rightOperand = process(node.getRight(), context);

    Signature signature = comparisonExpressionSignature(node.getType(), leftOperand.getType(), rightOperand.getType());
    return call(signature, BOOLEAN, leftOperand, rightOperand);
}
/**
 * Translates a function call. Arguments are translated in order; the signature uses the
 * name's suffix, the visitor's function kind, the node's recorded return type and the
 * type signatures of the translated arguments.
 */
@Override
protected RowExpression visitFunctionCall(FunctionCall node, Void context)
{
    ImmutableList.Builder<RowExpression> translatedArguments = ImmutableList.builder();
    for (Expression argument : node.getArguments()) {
        translatedArguments.add(process(argument, context));
    }
    List<RowExpression> arguments = translatedArguments.build();

    ImmutableList.Builder<TypeSignature> signatureArguments = ImmutableList.builder();
    for (RowExpression argument : arguments) {
        signatureArguments.add(argument.getType().getTypeSignature());
    }
    List<TypeSignature> argumentTypes = signatureArguments.build();

    Type returnType = types.get(node);
    Signature signature = new Signature(node.getName().getSuffix(), functionKind, returnType.getTypeSignature(), argumentTypes);
    return call(signature, returnType, arguments);
}
/**
 * Translates a symbol reference into a variable reference with the same name
 * and the node's recorded type.
 */
@Override
protected RowExpression visitSymbolReference(SymbolReference node, Void context)
{
    Type symbolType = types.get(node);
    return new VariableReferenceExpression(node.getName(), symbolType);
}
/**
 * Translates a lambda into a {@link LambdaDefinitionExpression}. Argument types are all
 * type parameters of the lambda's recorded type except the last one (presumably the
 * return type; confirm against the function type definition). Argument names come from
 * the lambda's declarations.
 */
@Override
protected RowExpression visitLambdaExpression(LambdaExpression node, Void context)
{
    RowExpression translatedBody = process(node.getBody(), context);

    List<Type> typeParameters = types.get(node).getTypeParameters();
    int argumentCount = typeParameters.size() - 1;
    List<Type> argumentTypes = typeParameters.subList(0, argumentCount);

    ImmutableList.Builder<String> argumentNames = ImmutableList.builder();
    for (LambdaArgumentDeclaration declaration : node.getArguments()) {
        argumentNames.add(declaration.getName());
    }

    return new LambdaDefinitionExpression(argumentTypes, argumentNames.build(), translatedBody);
}
/**
 * Translates a bind expression into a call taking the translated bound value and function,
 * typed with the node's recorded type.
 */
@Override
protected RowExpression visitBindExpression(BindExpression node, Void context)
{
    RowExpression boundValue = process(node.getValue(), context);
    RowExpression boundFunction = process(node.getFunction(), context);

    Type resultType = types.get(node);
    Signature signature = bindSignature(resultType, boundValue.getType(), boundFunction.getType());
    return call(signature, resultType, boundValue, boundFunction);
}
/**
 * Translates a binary arithmetic expression into a call whose signature is derived from
 * the operator, the node's recorded type and the types of the translated operands.
 */
@Override
protected RowExpression visitArithmeticBinary(ArithmeticBinaryExpression node, Void context)
{
    RowExpression leftOperand = process(node.getLeft(), context);
    RowExpression rightOperand = process(node.getRight(), context);

    Type resultType = types.get(node);
    Signature signature = arithmeticExpressionSignature(node.getType(), resultType, leftOperand.getType(), rightOperand.getType());
    return call(signature, resultType, leftOperand, rightOperand);
}
/**
 * Translates a unary arithmetic expression. Unary plus returns the translated operand
 * unchanged; unary minus wraps it in a negation call.
 *
 * @throws UnsupportedOperationException for any other sign
 */
@Override
protected RowExpression visitArithmeticUnary(ArithmeticUnaryExpression node, Void context)
{
    RowExpression operand = process(node.getValue(), context);

    switch (node.getSign()) {
        case MINUS: {
            Type resultType = types.get(node);
            return call(arithmeticNegationSignature(resultType, operand.getType()), resultType, operand);
        }
        case PLUS:
            return operand;
        default:
            throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
    }
}
/**
 * Translates a logical binary expression into a {@code BOOLEAN} call over the
 * translated left and right operands.
 */
@Override
protected RowExpression visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context)
{
    Signature signature = logicalExpressionSignature(node.getType());
    RowExpression leftOperand = process(node.getLeft(), context);
    RowExpression rightOperand = process(node.getRight(), context);
    return call(signature, BOOLEAN, leftOperand, rightOperand);
}
/**
 * Translates a cast. A type-only cast re-types the translated operand without emitting a
 * call; a safe cast becomes a try-cast call; any other cast becomes a regular cast call.
 */
@Override
protected RowExpression visitCast(Cast node, Void context)
{
    RowExpression operand = process(node.getExpression(), context);
    Type targetType = types.get(node);

    if (node.isTypeOnly()) {
        return changeType(operand, targetType);
    }

    Signature signature = node.isSafe()
            ? tryCastSignature(targetType, operand.getType())
            : castSignature(targetType, operand.getType());
    return call(signature, targetType, operand);
}
/**
 * Returns a copy of {@code value} re-typed to {@code targetType}, produced by a
 * {@link ChangeTypeVisitor}.
 */
private static RowExpression changeType(RowExpression value, Type targetType)
{
    return value.accept(new ChangeTypeVisitor(targetType), null);
}
/**
 * Rebuilds the visited row expression with {@code targetType} as its type, leaving the
 * rest of the expression untouched. Lambdas are not supported.
 */
private static class ChangeTypeVisitor
        implements RowExpressionVisitor<Void, RowExpression>
{
    // Type assigned to every expression this visitor rebuilds.
    private final Type targetType;

    private ChangeTypeVisitor(Type targetType)
    {
        this.targetType = targetType;
    }

    /**
     * Rebuilds the call with its original signature and arguments, but with the target type.
     */
    @Override
    public RowExpression visitCall(CallExpression callExpression, Void context)
    {
        Signature originalSignature = callExpression.getSignature();
        List<RowExpression> originalArguments = callExpression.getArguments();
        return new CallExpression(originalSignature, targetType, originalArguments);
    }

    /**
     * Rebuilds the input reference for the same field with the target type.
     */
    @Override
    public RowExpression visitInputReference(InputReferenceExpression inputReference, Void context)
    {
        int fieldIndex = inputReference.getField();
        return field(fieldIndex, targetType);
    }

    /**
     * Rebuilds the constant with the same value and the target type.
     */
    @Override
    public RowExpression visitConstant(ConstantExpression constantExpression, Void context)
    {
        Object originalValue = constantExpression.getValue();
        return constant(originalValue, targetType);
    }

    /**
     * Lambdas cannot be re-typed.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
    {
        throw new UnsupportedOperationException();
    }

    /**
     * Rebuilds the variable reference with the same name and the target type.
     */
    @Override
    public RowExpression visitVariableReference(VariableReferenceExpression variableReference, Void context)
    {
        String variableName = variableReference.getName();
        return new VariableReferenceExpression(variableName, targetType);
    }
}
/**
 * Translates {@code COALESCE} into a call over the translated operands, typed with
 * the node's recorded type.
 */
@Override
protected RowExpression visitCoalesceExpression(CoalesceExpression node, Void context)
{
    ImmutableList.Builder<RowExpression> translatedOperands = ImmutableList.builder();
    for (Expression operand : node.getOperands()) {
        translatedOperands.add(process(operand, context));
    }
    List<RowExpression> arguments = translatedOperands.build();

    ImmutableList.Builder<Type> operandTypes = ImmutableList.builder();
    for (RowExpression argument : arguments) {
        operandTypes.add(argument.getType());
    }
    List<Type> argumentTypes = operandTypes.build();

    Type resultType = types.get(node);
    return call(coalesceSignature(resultType, argumentTypes), resultType, arguments);
}
/**
 * Translates a simple {@code CASE operand WHEN ... THEN ... [ELSE ...] END} into a switch
 * call. The argument list is: the translated operand, one "when" call per clause (clause
 * operand and result), and finally the default value, or a typed null when no
 * {@code ELSE} is present.
 */
@Override
protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Void context)
{
    ImmutableList.Builder<RowExpression> switchArguments = ImmutableList.builder();
    switchArguments.add(process(node.getOperand(), context));

    for (WhenClause whenClause : node.getWhenClauses()) {
        Type clauseType = types.get(whenClause);
        Signature clauseSignature = whenSignature(clauseType);
        RowExpression clauseOperand = process(whenClause.getOperand(), context);
        RowExpression clauseResult = process(whenClause.getResult(), context);
        switchArguments.add(call(clauseSignature, clauseType, clauseOperand, clauseResult));
    }

    Type returnType = types.get(node);
    if (node.getDefaultValue().isPresent()) {
        switchArguments.add(process(node.getDefaultValue().get(), context));
    }
    else {
        switchArguments.add(constantNull(returnType));
    }

    return call(switchSignature(returnType), returnType, switchArguments.build());
}
/**
 * Translates a searched {@code CASE} into a chain of nested IF calls.
 */
@Override
protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, Void context)
{
    /*
        An expression such as:
            case when cond1 then value1
                 when cond2 then value2
                 when cond3 then value3
                 else value4
            end

        becomes:
            IF(cond1,
                value1,
                IF(cond2,
                    value2,
                    IF(cond3,
                        value3,
                        value4)))

        The chain is built from the innermost IF outwards, so the clauses are
        walked last to first. A missing else branch is a null of the result type.
     */
    Type resultType = types.get(node);

    RowExpression nested;
    if (node.getDefaultValue().isPresent()) {
        nested = process(node.getDefaultValue().get(), context);
    }
    else {
        nested = constantNull(resultType);
    }

    List<WhenClause> whenClauses = node.getWhenClauses();
    for (int i = whenClauses.size() - 1; i >= 0; i--) {
        WhenClause whenClause = whenClauses.get(i);
        RowExpression condition = process(whenClause.getOperand(), context);
        RowExpression result = process(whenClause.getResult(), context);
        nested = call(Signatures.ifSignature(resultType), resultType, condition, result, nested);
    }
    return nested;
}
/**
 * Translates a row field dereference into a call taking the translated base and the
 * position of the field within the base's row type. The field is found by a
 * case-insensitive name match; unnamed fields never match.
 *
 * @throws IllegalArgumentException if more than one field matches the name
 * @throws IllegalStateException if no field matches the name
 */
@Override
protected RowExpression visitDereferenceExpression(DereferenceExpression node, Void context)
{
    RowType rowType = (RowType) types.get(node.getBase());
    List<RowField> rowFields = rowType.getFields();

    int matchedIndex = -1;
    for (int position = 0; position < rowFields.size(); position++) {
        RowField candidate = rowFields.get(position);
        if (!candidate.getName().isPresent() || !candidate.getName().get().equalsIgnoreCase(node.getFieldName())) {
            continue;
        }
        // a second match means the name does not identify a single field
        checkArgument(matchedIndex < 0, "Ambiguous field %s in type %s", candidate, rowType.getDisplayName());
        matchedIndex = position;
    }
    checkState(matchedIndex >= 0, "could not find field name: %s", node.getFieldName());

    Type returnType = types.get(node);
    Signature signature = dereferenceSignature(returnType, rowType);
    RowExpression base = process(node.getBase(), context);
    // NOTE(review): the index is boxed as an Integer here, while visitLongLiteral passes a
    // Long for INTEGER constants; confirm consumers accept both representations.
    return call(signature, returnType, base, constant(matchedIndex, INTEGER));
}
/**
 * Translates {@code IF(condition, trueValue[, falseValue])} into an IF call with three
 * arguments; a missing false value is replaced by a null of the node's recorded type.
 */
@Override
protected RowExpression visitIfExpression(IfExpression node, Void context)
{
    Type resultType = types.get(node);

    RowExpression condition = process(node.getCondition(), context);
    RowExpression trueValue = process(node.getTrueValue(), context);
    RowExpression falseValue;
    if (node.getFalseValue().isPresent()) {
        falseValue = process(node.getFalseValue().get(), context);
    }
    else {
        falseValue = constantNull(resultType);
    }

    List<RowExpression> arguments = ImmutableList.of(condition, trueValue, falseValue);
    return call(Signatures.ifSignature(resultType), resultType, arguments);
}
/**
 * Translates {@code TRY(expression)} into a try call wrapping the translated inner
 * expression, typed with the node's recorded type.
 */
@Override
protected RowExpression visitTryExpression(TryExpression node, Void context)
{
    Type resultType = types.get(node);
    Signature signature = Signatures.trySignature(resultType);
    RowExpression inner = process(node.getInnerExpression(), context);
    return call(signature, resultType, inner);
}
/**
 * Translates {@code value IN (list)} into a {@code BOOLEAN} call whose first argument is
 * the translated value, followed by each translated list element in order. The value list
 * is cast to {@link InListExpression}, so any other kind of value list fails the cast.
 */
@Override
protected RowExpression visitInPredicate(InPredicate node, Void context)
{
    RowExpression needle = process(node.getValue(), context);

    InListExpression inList = (InListExpression) node.getValueList();
    List<RowExpression> candidates = inList.getValues().stream()
            .map(candidate -> process(candidate, context))
            .collect(toImmutableList());

    List<RowExpression> arguments = ImmutableList.<RowExpression>builder()
            .add(needle)
            .addAll(candidates)
            .build();
    return call(Signatures.inSignature(), BOOLEAN, arguments);
}
/**
 * Translates {@code value IS NOT NULL} as NOT applied to an IS NULL call on the
 * translated value.
 */
@Override
protected RowExpression visitIsNotNullPredicate(IsNotNullPredicate node, Void context)
{
    RowExpression operand = process(node.getValue(), context);
    Signature isNullSignature = Signatures.isNullSignature(operand.getType());
    RowExpression isNull = call(isNullSignature, BOOLEAN, ImmutableList.of(operand));
    return call(Signatures.notSignature(), BOOLEAN, isNull);
}
/**
 * Translates {@code value IS NULL} into a {@code BOOLEAN} IS NULL call on the
 * translated value.
 */
@Override
protected RowExpression visitIsNullPredicate(IsNullPredicate node, Void context)
{
    RowExpression operand = process(node.getValue(), context);
    Signature signature = Signatures.isNullSignature(operand.getType());
    return call(signature, BOOLEAN, operand);
}
/**
 * Translates {@code NOT value} into a {@code BOOLEAN} NOT call on the translated value.
 */
@Override
protected RowExpression visitNotExpression(NotExpression node, Void context)
{
    Signature signature = Signatures.notSignature();
    RowExpression operand = process(node.getValue(), context);
    return call(signature, BOOLEAN, operand);
}
/**
 * Translates {@code NULLIF(first, second)} into a call over both translated operands,
 * typed with the node's recorded type.
 */
@Override
protected RowExpression visitNullIfExpression(NullIfExpression node, Void context)
{
    RowExpression firstOperand = process(node.getFirst(), context);
    RowExpression secondOperand = process(node.getSecond(), context);

    Type resultType = types.get(node);
    Signature signature = nullIfSignature(resultType, firstOperand.getType(), secondOperand.getType());
    return call(signature, resultType, firstOperand, secondOperand);
}
/**
 * Translates {@code value BETWEEN min AND max} into a {@code BOOLEAN} call over the
 * translated value, lower bound and upper bound.
 */
@Override
protected RowExpression visitBetweenPredicate(BetweenPredicate node, Void context)
{
    RowExpression subject = process(node.getValue(), context);
    RowExpression lowerBound = process(node.getMin(), context);
    RowExpression upperBound = process(node.getMax(), context);

    Signature signature = betweenSignature(subject.getType(), lowerBound.getType(), upperBound.getType());
    return call(signature, BOOLEAN, subject, lowerBound, upperBound);
}
/**
 * Translates {@code value LIKE pattern [ESCAPE escape]} into a {@code BOOLEAN} LIKE call.
 * With an escape, the pattern argument is a like-pattern call over pattern and escape;
 * without one, it is a cast of the pattern from {@code VARCHAR} to the like-pattern type.
 */
@Override
protected RowExpression visitLikePredicate(LikePredicate node, Void context)
{
    RowExpression subject = process(node.getValue(), context);
    RowExpression rawPattern = process(node.getPattern(), context);

    RowExpression likePattern;
    if (node.getEscape() == null) {
        likePattern = call(castSignature(LIKE_PATTERN, VARCHAR), LIKE_PATTERN, rawPattern);
    }
    else {
        RowExpression escape = process(node.getEscape(), context);
        likePattern = call(likePatternSignature(), LIKE_PATTERN, rawPattern, escape);
    }
    return call(likeSignature(), BOOLEAN, subject, likePattern);
}
/**
 * Translates {@code base[index]} into a subscript call over the translated base and
 * index, typed with the node's recorded type.
 */
@Override
protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void context)
{
    RowExpression container = process(node.getBase(), context);
    RowExpression position = process(node.getIndex(), context);

    Type resultType = types.get(node);
    Signature signature = subscriptSignature(resultType, container.getType(), position.getType());
    return call(signature, resultType, container, position);
}
/**
 * Translates {@code ARRAY[...]} into an array constructor call over the translated
 * elements; the signature uses the types of the translated elements.
 */
@Override
protected RowExpression visitArrayConstructor(ArrayConstructor node, Void context)
{
    ImmutableList.Builder<RowExpression> translatedElements = ImmutableList.builder();
    for (Expression element : node.getValues()) {
        translatedElements.add(process(element, context));
    }
    List<RowExpression> arguments = translatedElements.build();

    ImmutableList.Builder<Type> elementTypes = ImmutableList.builder();
    for (RowExpression argument : arguments) {
        elementTypes.add(argument.getType());
    }
    List<Type> argumentTypes = elementTypes.build();

    Type resultType = types.get(node);
    return call(arrayConstructorSignature(resultType, argumentTypes), resultType, arguments);
}
/**
 * Translates {@code ROW(...)} into a row constructor call over the translated items.
 * The signature's argument types are the recorded types of the AST items, not the
 * types of the translated expressions.
 */
@Override
protected RowExpression visitRow(Row node, Void context)
{
    ImmutableList.Builder<RowExpression> translatedItems = ImmutableList.builder();
    for (Expression item : node.getItems()) {
        translatedItems.add(process(item, context));
    }
    List<RowExpression> arguments = translatedItems.build();

    Type returnType = types.get(node);

    ImmutableList.Builder<Type> itemTypes = ImmutableList.builder();
    for (Expression item : node.getItems()) {
        itemTypes.add(types.get(item));
    }
    List<Type> argumentTypes = itemTypes.build();

    return call(rowConstructorSignature(returnType, argumentTypes), returnType, arguments);
}
}
}