/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolToInputRewriter;
import com.facebook.presto.sql.relational.CallExpression;
import com.facebook.presto.sql.relational.ConstantExpression;
import com.facebook.presto.sql.relational.InputReferenceExpression;
import com.facebook.presto.sql.relational.LambdaDefinitionExpression;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.RowExpressionVisitor;
import com.facebook.presto.sql.relational.VariableReferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.util.maps.IdentityLinkedHashMap;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import io.airlift.slice.Slice;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import static com.facebook.presto.metadata.FunctionKind.SCALAR;
import static com.facebook.presto.metadata.FunctionRegistry.mangleOperatorName;
import static com.facebook.presto.metadata.Signature.internalScalarFunction;
import static com.facebook.presto.spi.function.OperatorType.EQUAL;
import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.spi.function.OperatorType.IS_DISTINCT_FROM;
import static com.facebook.presto.spi.function.OperatorType.LESS_THAN;
import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypesFromInput;
import static com.facebook.presto.sql.relational.SqlToRowExpressionTranslator.translate;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Integer.min;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
/**
 * Decides whether two SQL {@link Expression}s are equivalent by translating both into
 * {@link RowExpression}s, reducing each to a canonical form, and comparing the canonical forms
 * with {@link Object#equals}.
 * <p>
 * Canonicalization currently:
 * <ul>
 * <li>flattens nested {@code AND} / {@code OR} calls, drops duplicate operands and sorts the rest;</li>
 * <li>sorts the operands of {@code EQUAL}, {@code NOT_EQUAL} and {@code IS_DISTINCT_FROM};</li>
 * <li>rewrites {@code GREATER_THAN} / {@code GREATER_THAN_OR_EQUAL} as {@code LESS_THAN} /
 * {@code LESS_THAN_OR_EQUAL} with the operands swapped.</li>
 * </ul>
 * Because only these rewrites are applied, a {@code false} result means "not shown to be
 * equivalent", not "shown to be different".
 */
public class ExpressionEquivalence
{
    private static final Ordering<RowExpression> ROW_EXPRESSION_ORDERING = Ordering.from(new RowExpressionComparator());
    private static final CanonicalizationVisitor CANONICALIZATION_VISITOR = new CanonicalizationVisitor();

    private final Metadata metadata;
    private final SqlParser sqlParser;

    public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
    }

    /**
     * Returns whether the two expressions have identical canonical forms.
     *
     * @param session session used for type analysis and translation
     * @param leftExpression first expression to compare
     * @param rightExpression second expression to compare
     * @param types types of every symbol that may be referenced by either expression
     * @return {@code true} if both expressions canonicalize to equal row expressions
     */
    public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, Map<Symbol, Type> types)
    {
        // Give every symbol an input channel. Both sides share one mapping, so the same symbol
        // becomes the same InputReferenceExpression on the left and on the right.
        Map<Symbol, Integer> symbolInput = new HashMap<>();
        Map<Integer, Type> inputTypes = new HashMap<>();
        int inputId = 0;
        for (Entry<Symbol, Type> entry : types.entrySet()) {
            symbolInput.put(entry.getKey(), inputId);
            inputTypes.put(inputId, entry.getValue());
            inputId++;
        }

        RowExpression leftRowExpression = toRowExpression(session, leftExpression, symbolInput, inputTypes);
        RowExpression rightRowExpression = toRowExpression(session, rightExpression, symbolInput, inputTypes);

        RowExpression canonicalizedLeft = leftRowExpression.accept(CANONICALIZATION_VISITOR, null);
        RowExpression canonicalizedRight = rightRowExpression.accept(CANONICALIZATION_VISITOR, null);

        return canonicalizedLeft.equals(canonicalizedRight);
    }

    /**
     * Translates an AST expression into a {@link RowExpression}, replacing symbol references
     * with input references according to {@code symbolInput}.
     */
    private RowExpression toRowExpression(Session session, Expression expression, Map<Symbol, Integer> symbolInput, Map<Integer, Type> inputTypes)
    {
        // replace qualified names with input references since row expressions do not support these
        Expression expressionWithInputReferences = new SymbolToInputRewriter(symbolInput).rewrite(expression);

        // determine the type of every expression
        IdentityLinkedHashMap<Expression, Type> expressionTypes = getExpressionTypesFromInput(
                session,
                metadata,
                sqlParser,
                inputTypes,
                expressionWithInputReferences,
                emptyList() /* parameters have already been replaced */);

        // convert to row expression
        return translate(expressionWithInputReferences, SCALAR, expressionTypes, metadata.getFunctionRegistry(), metadata.getTypeManager(), session, false);
    }

    /**
     * Rewrites a row expression bottom-up into a canonical form (see the class documentation
     * for the rewrites applied). Leaf nodes are returned unchanged.
     */
    private static class CanonicalizationVisitor
            implements RowExpressionVisitor<Void, RowExpression>
    {
        @Override
        public RowExpression visitCall(CallExpression call, Void context)
        {
            // canonicalize the children first so the rules below see canonical operands
            call = new CallExpression(
                    call.getSignature(),
                    call.getType(),
                    call.getArguments().stream()
                            .map(expression -> expression.accept(this, context))
                            .collect(toImmutableList()));

            String callName = call.getSignature().getName();

            if (callName.equals("AND") || callName.equals("OR")) {
                // if we have nested calls (of the same type) flatten them
                List<RowExpression> flattenedArguments = flattenNestedCallArgs(call);

                // only consider distinct arguments
                Set<RowExpression> distinctArguments = ImmutableSet.copyOf(flattenedArguments);
                if (distinctArguments.size() == 1) {
                    return Iterables.getOnlyElement(distinctArguments);
                }

                // canonicalize the argument order (i.e., sort them)
                List<RowExpression> sortedArguments = ROW_EXPRESSION_ORDERING.sortedCopy(distinctArguments);

                // derive the signature from the same (sorted) list that becomes the call's arguments
                return new CallExpression(
                        internalScalarFunction(
                                callName,
                                BOOLEAN.getTypeSignature(),
                                sortedArguments.stream()
                                        .map(RowExpression::getType)
                                        .map(Type::getTypeSignature)
                                        .collect(toImmutableList())),
                        BOOLEAN,
                        sortedArguments);
            }

            if (callName.equals(mangleOperatorName(EQUAL)) || callName.equals(mangleOperatorName(NOT_EQUAL)) || callName.equals(mangleOperatorName(IS_DISTINCT_FROM))) {
                // symmetric operators: operand order is irrelevant, so sort the operands
                return new CallExpression(
                        call.getSignature(),
                        call.getType(),
                        ROW_EXPRESSION_ORDERING.sortedCopy(call.getArguments()));
            }

            if (callName.equals(mangleOperatorName(GREATER_THAN)) || callName.equals(mangleOperatorName(GREATER_THAN_OR_EQUAL))) {
                // convert greater than to less than: a > b  ==>  b < a,  a >= b  ==>  b <= a
                return new CallExpression(
                        new Signature(
                                callName.equals(mangleOperatorName(GREATER_THAN)) ? mangleOperatorName(LESS_THAN) : mangleOperatorName(LESS_THAN_OR_EQUAL),
                                SCALAR,
                                call.getSignature().getTypeVariableConstraints(),
                                call.getSignature().getLongVariableConstraints(),
                                call.getSignature().getReturnType(),
                                swapPair(call.getSignature().getArgumentTypes()),
                                false),
                        call.getType(),
                        swapPair(call.getArguments()));
            }

            return call;
        }

        /**
         * Returns the arguments of {@code call}, recursively inlining the arguments of any
         * argument that is a call with the same function name (e.g. {@code AND(a, AND(b, c))}
         * yields {@code [a, b, c]}).
         */
        public static List<RowExpression> flattenNestedCallArgs(CallExpression call)
        {
            String callName = call.getSignature().getName();

            ImmutableList.Builder<RowExpression> newArguments = ImmutableList.builder();
            for (RowExpression argument : call.getArguments()) {
                if (argument instanceof CallExpression && callName.equals(((CallExpression) argument).getSignature().getName())) {
                    // same call type, so flatten the args
                    newArguments.addAll(flattenNestedCallArgs((CallExpression) argument));
                }
                else {
                    newArguments.add(argument);
                }
            }
            return newArguments.build();
        }

        @Override
        public RowExpression visitConstant(ConstantExpression constant, Void context)
        {
            return constant;
        }

        @Override
        public RowExpression visitInputReference(InputReferenceExpression node, Void context)
        {
            return node;
        }

        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            // only the body can contain calls that need canonicalizing
            return new LambdaDefinitionExpression(lambda.getArgumentTypes(), lambda.getArguments(), lambda.getBody().accept(this, context));
        }

        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            return reference;
        }
    }

    /**
     * Total order over row expressions, used only to put operands into a deterministic order
     * during canonicalization. Expressions of different node classes are ordered by
     * {@link Ordering#arbitrary()}; expressions of the same class are compared structurally.
     */
    private static class RowExpressionComparator
            implements Comparator<RowExpression>
    {
        // A total order that stays consistent for the lifetime of the JVM (not across JVMs).
        private final Comparator<Object> arbitraryOrdering = Ordering.arbitrary();
        private final ListComparator<RowExpression> argumentComparator = new ListComparator<>(this);
        private final ListComparator<Type> typeListComparator = new ListComparator<>(Comparator.comparing((Type type) -> type.getTypeSignature().toString()));
        private final ListComparator<String> nameListComparator = new ListComparator<>(Comparator.<String>naturalOrder());

        @Override
        public int compare(RowExpression left, RowExpression right)
        {
            int result = arbitraryOrdering.compare(left.getClass(), right.getClass());
            if (result != 0) {
                return result;
            }

            // from here on, left and right are instances of the same class
            if (left instanceof CallExpression) {
                CallExpression leftCall = (CallExpression) left;
                CallExpression rightCall = (CallExpression) right;
                return ComparisonChain.start()
                        .compare(leftCall.getSignature().toString(), rightCall.getSignature().toString())
                        .compare(leftCall.getArguments(), rightCall.getArguments(), argumentComparator)
                        .result();
            }

            if (left instanceof ConstantExpression) {
                return compareConstants((ConstantExpression) left, (ConstantExpression) right);
            }

            if (left instanceof InputReferenceExpression) {
                return Integer.compare(((InputReferenceExpression) left).getField(), ((InputReferenceExpression) right).getField());
            }

            // Lambdas and variable references occur inside canonicalized calls, so they must be
            // comparable; otherwise sorting the operands of such calls would throw.
            if (left instanceof LambdaDefinitionExpression) {
                LambdaDefinitionExpression leftLambda = (LambdaDefinitionExpression) left;
                LambdaDefinitionExpression rightLambda = (LambdaDefinitionExpression) right;
                return ComparisonChain.start()
                        .compare(leftLambda.getArgumentTypes(), rightLambda.getArgumentTypes(), typeListComparator)
                        .compare(leftLambda.getArguments(), rightLambda.getArguments(), nameListComparator)
                        .compare(leftLambda.getBody(), rightLambda.getBody(), this)
                        .result();
            }

            if (left instanceof VariableReferenceExpression) {
                VariableReferenceExpression leftVariable = (VariableReferenceExpression) left;
                VariableReferenceExpression rightVariable = (VariableReferenceExpression) right;
                return ComparisonChain.start()
                        .compare(leftVariable.getName(), rightVariable.getName())
                        .compare(leftVariable.getType().getTypeSignature().toString(), rightVariable.getType().getTypeSignature().toString())
                        .result();
            }

            throw new IllegalArgumentException("Unsupported RowExpression type " + left.getClass().getSimpleName());
        }

        /**
         * Orders constants by type signature, then by value. Null values sort first. Java types
         * without a natural comparison fall back to the arbitrary ordering, which satisfies the
         * {@link Comparator} contract but is not stable across JVMs.
         */
        private int compareConstants(ConstantExpression leftConstant, ConstantExpression rightConstant)
        {
            int result = leftConstant.getType().getTypeSignature().toString().compareTo(rightConstant.getType().getTypeSignature().toString());
            if (result != 0) {
                return result;
            }

            Object leftValue = leftConstant.getValue();
            Object rightValue = rightConstant.getValue();

            // getValue() may be null (presumably a SQL NULL literal); handle it before the casts
            // below, which would otherwise throw NullPointerException
            if (leftValue == null || rightValue == null) {
                if (leftValue == rightValue) {
                    return 0;
                }
                return leftValue == null ? -1 : 1;
            }

            Class<?> javaType = leftConstant.getType().getJavaType();
            if (javaType == boolean.class) {
                return ((Boolean) leftValue).compareTo((Boolean) rightValue);
            }
            if (javaType == byte.class || javaType == short.class || javaType == int.class || javaType == long.class) {
                return Long.compare(((Number) leftValue).longValue(), ((Number) rightValue).longValue());
            }
            if (javaType == float.class || javaType == double.class) {
                return Double.compare(((Number) leftValue).doubleValue(), ((Number) rightValue).doubleValue());
            }
            if (javaType == Slice.class) {
                return ((Slice) leftValue).compareTo((Slice) rightValue);
            }

            // Value is of some other type (say regex). Fall back to a consistent arbitrary total
            // order: returning a constant here would violate the Comparator contract and can make
            // sorting throw. todo: support all known types
            return arbitraryOrdering.compare(leftValue, rightValue);
        }
    }

    /**
     * Lexicographic comparator for lists: compares element-wise with the supplied element
     * comparator, and when one list is a prefix of the other the shorter list sorts first.
     */
    private static class ListComparator<T>
            implements Comparator<List<T>>
    {
        private final Comparator<T> elementComparator;

        public ListComparator(Comparator<T> elementComparator)
        {
            this.elementComparator = requireNonNull(elementComparator, "elementComparator is null");
        }

        @Override
        public int compare(List<T> left, List<T> right)
        {
            int compareLength = min(left.size(), right.size());
            for (int i = 0; i < compareLength; i++) {
                int result = elementComparator.compare(left.get(i), right.get(i));
                if (result != 0) {
                    return result;
                }
            }
            return Integer.compare(left.size(), right.size());
        }
    }

    /**
     * Returns an immutable copy of a two-element list with its elements in reverse order.
     *
     * @throws NullPointerException if {@code pair} is null
     * @throws IllegalArgumentException if {@code pair} does not have exactly two elements
     */
    private static <T> List<T> swapPair(List<T> pair)
    {
        requireNonNull(pair, "pair is null");
        checkArgument(pair.size() == 2, "Expected pair to have two elements");
        return ImmutableList.of(pair.get(1), pair.get(0));
    }
}