/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;
import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FieldReference;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.SymbolReference;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
/**
* Currently this class handles only simple expressions like:
*
* A.a < B.x.
*
* It could be extended to handle arbitrary expressions such as:
*
* A.a * sin(A.b) / log(B.x) < cos(B.z)
*
* by transforming it to:
*
* f(A.a, A.b) < g(B.x, B.z)
*
* where f(...) and g(...) are arbitrary functions/expressions. That
* would allow us to perform binary search on arbitrarily complex expressions
* by sorting position links according to the result of the f(...) function.
*/
public final class SortExpressionExtractor
{
    private SortExpressionExtractor() {}

    /**
     * Extracts the build-side expression a join filter could be sorted on.
     * <p>
     * The filter qualifies only when all of the following hold:
     * <ul>
     * <li>it is a {@link ComparisonExpression} of type {@code GREATER_THAN},
     * {@code GREATER_THAN_OR_EQUAL}, {@code LESS_THAN} or {@code LESS_THAN_OR_EQUAL};</li>
     * <li>one side is a bare {@link SymbolReference} whose name is one of {@code buildSymbols};</li>
     * <li>the opposite side references no build symbol at all (at any depth).</li>
     * </ul>
     * The right-hand side is tried as the sort channel first; the left-hand side is
     * tried only when the right-hand side is not a bare build-symbol reference.
     *
     * @param buildSymbols symbols considered to belong to the build side
     * @param filter the filter expression to inspect
     * @return the qualifying build-side symbol reference, or {@link Optional#empty()}
     * when the filter does not have the supported shape
     */
    public static Optional<Expression> extractSortExpression(Set<Symbol> buildSymbols, Expression filter)
    {
        if (filter instanceof ComparisonExpression) {
            ComparisonExpression comparison = (ComparisonExpression) filter;
            switch (comparison.getType()) {
                case GREATER_THAN:
                case GREATER_THAN_OR_EQUAL:
                case LESS_THAN:
                case LESS_THAN_OR_EQUAL:
                    Optional<SymbolReference> sortChannel = asBuildSymbolReference(buildSymbols, comparison.getRight());
                    boolean hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getLeft());
                    if (!sortChannel.isPresent()) {
                        // Right side is not a bare build symbol; try the mirrored orientation.
                        sortChannel = asBuildSymbolReference(buildSymbols, comparison.getLeft());
                        hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getRight());
                    }
                    if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) {
                        // Target typing widens SymbolReference to Expression without a cast.
                        return Optional.of(sortChannel.get());
                    }
                    return Optional.empty();
                default:
                    return Optional.empty();
            }
        }
        return Optional.empty();
    }

    /**
     * Returns {@code expression} as a {@link SymbolReference} if it is a bare symbol
     * reference whose name is contained in {@code buildLayout}; otherwise empty.
     */
    private static Optional<SymbolReference> asBuildSymbolReference(Set<Symbol> buildLayout, Expression expression)
    {
        if (expression instanceof SymbolReference) {
            SymbolReference symbolReference = (SymbolReference) expression;
            if (buildLayout.contains(new Symbol(symbolReference.getName()))) {
                return Optional.of(symbolReference);
            }
        }
        return Optional.empty();
    }

    /**
     * Returns {@code true} if any {@link SymbolReference} anywhere inside
     * {@code expression} names one of {@code buildSymbols}.
     */
    private static boolean hasBuildSymbolReference(Set<Symbol> buildSymbols, Expression expression)
    {
        return new BuildSymbolReferenceFinder(buildSymbols).process(expression);
    }

    /**
     * AST visitor that answers whether a subtree contains a reference to any build symbol.
     * Every node's children are searched depth-first, stopping at the first match.
     */
    private static final class BuildSymbolReferenceFinder
            extends AstVisitor<Boolean, Void>
    {
        // Symbol names, so SymbolReference nodes can be matched directly by name.
        private final Set<String> buildSymbols;

        public BuildSymbolReferenceFinder(Set<Symbol> buildSymbols)
        {
            this.buildSymbols = requireNonNull(buildSymbols, "buildSymbols is null").stream()
                    .map(Symbol::getName)
                    .collect(toImmutableSet());
        }

        @Override
        protected Boolean visitNode(Node node, Void context)
        {
            for (Node child : node.getChildren()) {
                if (process(child, context)) {
                    return true;
                }
            }
            return false;
        }

        @Override
        protected Boolean visitSymbolReference(SymbolReference symbolReference, Void context)
        {
            return buildSymbols.contains(symbolReference.getName());
        }
    }

    /**
     * Value object identifying the channel (field index) used as a sort key.
     */
    public static class SortExpression
    {
        private final int channel;

        public SortExpression(int channel)
        {
            this.channel = channel;
        }

        public int getChannel()
        {
            return channel;
        }

        @Override
        public boolean equals(Object obj)
        {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            SortExpression other = (SortExpression) obj;
            // Primitive comparison; avoids the autoboxing Objects.equals would incur.
            return channel == other.channel;
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(channel);
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .add("channel", channel)
                    .toString();
        }

        /**
         * Creates a {@link SortExpression} from a {@link FieldReference}, using its field index as the channel.
         *
         * @param expression the expression to convert; must be a {@link FieldReference}
         * @return a sort expression for the referenced field
         * @throws IllegalStateException if {@code expression} is not a {@link FieldReference}
         */
        public static SortExpression fromExpression(Expression expression)
        {
            checkState(expression instanceof FieldReference, "Unsupported expression type [%s]", expression);
            return new SortExpression(((FieldReference) expression).getFieldIndex());
        }
    }
}