package io.crate.analyze.where; import com.google.common.collect.ImmutableSet; import io.crate.analyze.WhereClause; import io.crate.analyze.symbol.Field; import io.crate.analyze.symbol.Function; import io.crate.analyze.symbol.Symbol; import io.crate.analyze.symbol.SymbolVisitor; import io.crate.metadata.Reference; import io.crate.operation.operator.EqOperator; import io.crate.operation.operator.GteOperator; import io.crate.operation.operator.any.AnyEqOperator; import io.crate.operation.predicate.NotPredicate; import io.crate.sql.tree.ComparisonExpression; import java.util.Locale; import java.util.Set; import java.util.Stack; public abstract class WhereClauseValidator { private static final Visitor visitor = new Visitor(); public static void validate(WhereClause whereClause) { if (whereClause.hasQuery()) { visitor.process(whereClause.query(), new Visitor.Context()); } } private static class Visitor extends SymbolVisitor<Visitor.Context, Symbol> { static class Context { private final Stack<Function> functions = new Stack<>(); private Context() { } } private static final String _SCORE = "_score"; private static final Set<String> SCORE_ALLOWED_COMPARISONS = ImmutableSet.of(GteOperator.NAME); private static final String _VERSION = "_version"; private static final Set<String> VERSION_ALLOWED_COMPARISONS = ImmutableSet.of(EqOperator.NAME, AnyEqOperator.NAME); private static final String VERSION_ERROR = "Filtering \"_version\" in WHERE clause only works using the \"=\" operator, checking for a numeric value"; private static final String SCORE_ERROR = String.format(Locale.ENGLISH, "System column '%s' can only be used within a '%s' comparison without any surrounded predicate", _SCORE, ComparisonExpression.Type.GREATER_THAN_OR_EQUAL.getValue()); @Override public Symbol visitField(Field field, Context context) { validateSysReference(context, field.path().outputName()); return super.visitField(field, context); } @Override public Symbol visitReference(Reference symbol, Context context) { validateSysReference(context, symbol.ident().columnIdent().name()); return super.visitReference(symbol, context); } @Override public Symbol visitFunction(Function function, Context context) { context.functions.push(function); continueTraversal(function, context); context.functions.pop(); return function; } private Function continueTraversal(Function symbol, Context context) { for (Symbol argument : symbol.arguments()) { process(argument, context); } return symbol; } private boolean insideNotPredicate(Context context) { for (Function function : context.functions) { if (function.info().ident().name().equals(NotPredicate.NAME)) { return true; } } return false; } private void validateSysReference(Context context, String columnName) { if (columnName.equalsIgnoreCase(_VERSION)) { validateSysReference(context, VERSION_ALLOWED_COMPARISONS, VERSION_ERROR); } else if (columnName.equalsIgnoreCase(_SCORE)) { validateSysReference(context, SCORE_ALLOWED_COMPARISONS, SCORE_ERROR); } } private void validateSysReference(Context context, Set<String> requiredFunctionNames, String error) { if (context.functions.isEmpty()) { throw new UnsupportedOperationException(error); } Function function = context.functions.lastElement(); if (!requiredFunctionNames.contains(function.info().ident().name().toLowerCase(Locale.ENGLISH)) || insideNotPredicate(context)) { throw new UnsupportedOperationException(error); } assert function.arguments().size() == 2 : "function's number of arguments must be 2"; Symbol right = function.arguments().get(1); if (!right.symbolType().isValueSymbol()) { throw new UnsupportedOperationException(error); } } } }