/*
* Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
* license agreements. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership. Crate licenses
* this file to you under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* However, if you have executed another commercial license agreement
* with Crate these terms will supersede the license and you may use the
* software solely pursuant to the terms of the relevant commercial agreement.
*/
package io.crate.operation.scalar;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.crate.action.sql.SessionContext;
import io.crate.analyze.relations.AnalyzedRelation;
import io.crate.analyze.relations.DocTableRelation;
import io.crate.analyze.symbol.*;
import io.crate.data.Input;
import io.crate.metadata.*;
import io.crate.metadata.doc.DocSchemaInfo;
import io.crate.metadata.doc.DocTableInfo;
import io.crate.metadata.table.TestingTableInfo;
import io.crate.operation.aggregation.FunctionExpression;
import io.crate.sql.tree.QualifiedName;
import io.crate.test.integration.CrateUnitTest;
import io.crate.testing.SqlExpressions;
import io.crate.types.ArrayType;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import io.crate.types.SetType;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.lucene.BytesRefs;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.junit.Before;
import java.util.*;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.core.Is.is;
public abstract class AbstractScalarFunctionsTest extends CrateUnitTest {
// Shared visitor used by assertEvaluate to turn argument symbols that are not Inputs into Inputs.
private static final InputApplier INPUT_APPLIER = new InputApplier();
// Expression helper created over tableSources in prepareFunctions(); parses and normalizes expressions.
protected SqlExpressions sqlExpressions;
// Function registry, taken from sqlExpressions.functions() in prepareFunctions().
protected Functions functions;
// Relations visible to the expressions under test; prepareFunctions() registers the "users" table here.
protected Map<QualifiedName, AnalyzedRelation> tableSources;
/**
 * Builds the "users" test table and initializes {@link #tableSources}, {@link #sqlExpressions}
 * and {@link #functions} before every test.
 */
@Before
public void prepareFunctions() throws Exception {
    TableIdent usersIdent = new TableIdent(DocSchemaInfo.NAME, "users");
    DocTableInfo usersTable = TestingTableInfo.builder(usersIdent, null)
        .add("id", DataTypes.INTEGER)
        .add("name", DataTypes.STRING)
        .add("tags", new ArrayType(DataTypes.STRING))
        .add("age", DataTypes.INTEGER)
        .add("a", DataTypes.INTEGER)
        .add("x", DataTypes.LONG)
        .add("shape", DataTypes.GEO_SHAPE)
        .add("timestamp", DataTypes.TIMESTAMP)
        .add("timezone", DataTypes.STRING)
        .add("interval", DataTypes.STRING)
        .add("time_format", DataTypes.STRING)
        .add("long_array", new ArrayType(DataTypes.LONG))
        .add("int_array", new ArrayType(DataTypes.INTEGER))
        .add("array_string_array", new ArrayType(new ArrayType(DataTypes.STRING)))
        .add("long_set", new SetType(DataTypes.LONG))
        .add("regex_pattern", DataTypes.STRING)
        .add("geoshape", DataTypes.GEO_SHAPE)
        .add("geopoint", DataTypes.GEO_POINT)
        .add("geostring", DataTypes.STRING)
        .add("is_awesome", DataTypes.BOOLEAN)
        .add("double_val", DataTypes.DOUBLE)
        // NOTE(review): "float_val" is registered as DOUBLE despite its name - confirm this is intended
        .add("float_val", DataTypes.DOUBLE)
        .add("short_val", DataTypes.SHORT)
        .add("obj", DataTypes.OBJECT, ImmutableList.of())
        .build();

    // the single relation is registered under the unqualified name "users"
    tableSources = ImmutableMap.of(new QualifiedName("users"), new DocTableRelation(usersTable));
    sqlExpressions = new SqlExpressions(tableSources);
    functions = sqlExpressions.functions();
}
/**
 * Asserts that the given function expression normalizes to a symbol accepted by
 * {@code expectedSymbol}.
 * <p>
 * Delegates to {@link #assertNormalize(String, Matcher, boolean)} with evaluation enabled, so
 * if normalization yields an Input and all arguments are Inputs, the function is also
 * evaluated and the result must equal the normalized value.
 *
 * @param functionExpression the SQL expression to analyze and normalize
 * @param expectedSymbol     matcher the normalized symbol has to satisfy
 */
@SuppressWarnings("unchecked")
public void assertNormalize(String functionExpression, Matcher<? super Symbol> expectedSymbol) {
    final boolean alsoEvaluate = true;
    assertNormalize(functionExpression, expectedSymbol, alsoEvaluate);
}
/**
 * Asserts that the given function expression normalizes to a symbol accepted by
 * {@code expectedSymbol}.
 * <p>
 * An expression that is already a Literal after analysis is matched directly. Otherwise the
 * function implementation is looked up (and must exist), the function is normalized and the
 * result is matched. When {@code evaluate} is set, the normalized symbol is an Input and all
 * arguments are Inputs, both the plain and the compiled scalar are evaluated and must return
 * the normalized value.
 *
 * @param functionExpression the SQL expression to analyze and normalize
 * @param expectedSymbol     matcher the normalized symbol has to satisfy
 * @param evaluate           whether to cross-check normalization against evaluation
 */
public void assertNormalize(String functionExpression, Matcher<? super Symbol> expectedSymbol, boolean evaluate) {
    Symbol analyzed = sqlExpressions.asSymbol(functionExpression);
    if (analyzed instanceof Literal) {
        assertThat(analyzed, expectedSymbol);
        return;
    }

    Function function = (Function) analyzed;
    FunctionImplementation impl = functions.getQualified(function.info().ident());
    assertThat(impl, Matchers.notNullValue());

    Symbol normalized = sqlExpressions.normalize(function);
    String reason = String.format(
        Locale.ENGLISH, "expected <%s> to normalize to %s", functionExpression, expectedSymbol);
    assertThat(reason, normalized, expectedSymbol);

    boolean crossCheck = evaluate
        && normalized instanceof Input
        && allArgsAreInputs(function.arguments());
    if (!crossCheck) {
        return;
    }

    // every argument is known to be an Input at this point, so the casts are safe
    List<Symbol> args = function.arguments();
    Input[] inputs = new Input[args.size()];
    int pos = 0;
    for (Symbol arg : args) {
        inputs[pos++] = (Input) arg;
    }

    Object expectedValue = ((Input) normalized).value();
    Scalar scalar = (Scalar) impl;
    assertThat(scalar.evaluate(inputs), is(expectedValue));
    assertThat(scalar.compile(function.arguments()).evaluate(inputs), is(expectedValue));
}
/**
 * Asserts that the given function expression evaluates to {@code expectedValue}, both through
 * the compiled scalar and through the plain scalar.
 * <p>
 * References inside the expression are replaced by the given inputs in the order in which the
 * references appear, e.g.
 * <code>
 * assertEvaluate("foo(name, age)", "expectedValue", inputForName, inputForAge)
 * </code>
 * or
 * <code>
 * assertEvaluate("foo('literalName', age)", "expectedValue", inputForAge)
 * </code>
 * <p>
 * A String {@code expectedValue} is converted to a BytesRef before comparison. If the
 * expression normalizes to a Literal, only the literal value is compared.
 *
 * @param functionExpression the SQL expression to evaluate
 * @param expectedValue      the value the function has to return
 * @param inputs             values for the references contained in the expression
 */
@SuppressWarnings("unchecked")
public void assertEvaluate(String functionExpression, Object expectedValue, Input... inputs) {
    Symbol normalized = sqlExpressions.normalize(sqlExpressions.asSymbol(functionExpression));
    Object expected = expectedValue instanceof String
        ? new BytesRef((String) expectedValue)
        : expectedValue;

    if (normalized instanceof Literal) {
        assertThat(((Literal) normalized).value(), is(expected));
        return;
    }

    Function function = (Function) normalized;
    Scalar scalar = (Scalar) functions.getQualified(function.info().ident());

    InputApplierContext applierContext = new InputApplierContext(inputs, sqlExpressions);
    AssertingInput[] arguments = new AssertingInput[function.arguments().size()];
    for (int i = 0; i < arguments.length; i++) {
        Symbol arg = function.arguments().get(i);
        Input source = arg instanceof Input
            ? (Input) arg
            : INPUT_APPLIER.process(arg, applierContext);
        arguments[i] = new AssertingInput(source);
    }

    // first pass: the compiled variant of the scalar
    Object compiledResult = scalar.compile(function.arguments()).evaluate((Input[]) arguments);
    if (expected instanceof BytesRef) {
        // readable output for the AssertionError
        assertThat(BytesRefs.toString(compiledResult), is(BytesRefs.toString(expected)));
    } else {
        assertThat(compiledResult, is(expected));
    }

    // second pass: the plain scalar; the call counters must start from zero again
    for (AssertingInput argument : arguments) {
        argument.calls = 0;
    }
    assertThat(scalar.evaluate((Input[]) arguments), is(expected));
}
/**
 * Asserts that compiling the scalar behind the given expression yields a Scalar accepted by
 * the matcher produced from the uncompiled scalar.
 * <p>
 * Fails if the expression normalizes to a Literal, because compile would not be reached then.
 *
 * @param functionExpression the SQL expression whose function is compiled
 * @param matcher            receives the uncompiled scalar and returns the matcher for the compiled one
 */
public void assertCompile(String functionExpression, java.util.function.Function<Scalar, Matcher<Scalar>> matcher) {
    Symbol normalized = sqlExpressions.normalize(sqlExpressions.asSymbol(functionExpression));
    assertThat(
        "function expression was normalized, compile would not be hit",
        normalized,
        not(instanceOf(Literal.class)));

    Function function = (Function) normalized;
    Scalar uncompiled = (Scalar) functions.getQualified(function.info().ident());
    Scalar compiled = uncompiled.compile(function.arguments());
    Matcher<Scalar> expectation = matcher.apply(uncompiled);
    assertThat(compiled, expectation);
}
/**
 * Returns true if every symbol in the list is also an Input (true for an empty list).
 */
private static boolean allArgsAreInputs(List<Symbol> arguments) {
    return arguments.stream().allMatch(argument -> argument instanceof Input);
}
/**
 * Looks up a function by name and argument types given as varargs.
 * Wraps the types in a list and delegates to {@link #getFunction(String, List)}.
 */
@SuppressWarnings("unchecked")
protected <T extends FunctionImplementation> T getFunction(String functionName, DataType... argTypes) {
    List<DataType> typeList = Arrays.asList(argTypes);
    return (T) getFunction(functionName, typeList);
}
/**
 * Looks up a function by name and argument types via {@code functions.getBuiltin} and casts
 * the result, unchecked, to the type expected by the caller.
 */
@SuppressWarnings("unchecked")
protected <T extends FunctionImplementation> T getFunction(String functionName, List<DataType> argTypes) {
    FunctionImplementation builtin = functions.getBuiltin(functionName, argTypes);
    return (T) builtin;
}
/**
 * Normalizes a call of {@code functionName} with a single literal argument built from
 * {@code value} and {@code type}.
 */
protected Symbol normalize(String functionName, Object value, DataType type) {
    Symbol literalArg = Literal.of(type, value);
    return normalize(functionName, literalArg);
}
/**
 * Resolves {@code functionName} for the value types of {@code args} and returns the result
 * of the implementation's {@code normalizeSymbol} for a function call built from those
 * arguments.
 *
 * @param transactionContext context handed to {@code normalizeSymbol}
 * @param functionName       name of the function to resolve
 * @param args               argument symbols; their value types select the implementation
 */
protected Symbol normalize(TransactionContext transactionContext, String functionName, Symbol... args) {
    DataType[] argTypes = Arrays.stream(args)
        .map(Symbol::valueType)
        .toArray(DataType[]::new);
    FunctionImplementation impl = getFunction(functionName, argTypes);
    Function call = new Function(impl.info(), Arrays.asList(args));
    return impl.normalizeSymbol(call, transactionContext);
}
/**
 * Normalizes a call of {@code functionName} with the given arguments, using a new
 * TransactionContext created from {@code SessionContext.SYSTEM_SESSION}.
 */
protected Symbol normalize(String functionName, Symbol... args) {
    TransactionContext systemContext = new TransactionContext(SessionContext.SYSTEM_SESSION);
    return normalize(systemContext, functionName, args);
}
private class AssertingInput implements Input {
private final Input delegate;
int calls = 0;
AssertingInput(Input delegate) {
this.delegate = delegate;
}
@Override
public Object value() {
calls++;
if (calls == 1) {
return delegate.value();
}
throw new AssertionError("Input.value() should only be called once");
}
@Override
public String toString() {
return delegate.toString();
}
}
/**
 * Visitor context that hands out the caller-supplied inputs one by one, in array order, and
 * carries the SqlExpressions instance used to normalize nested functions.
 */
private static class InputApplierContext implements Iterator<Input> {

    private final Iterator<Input> inputsIterator;
    private final SqlExpressions sqlExpressions;

    InputApplierContext(Input[] inputs, SqlExpressions sqlExpressions) {
        this.inputsIterator = Arrays.asList(inputs).iterator();
        this.sqlExpressions = sqlExpressions;
    }

    @Override
    public Input next() {
        return inputsIterator.next();
    }

    @Override
    public boolean hasNext() {
        return inputsIterator.hasNext();
    }

    /**
     * Removal is not supported; signal that as the Iterator contract requires instead of
     * silently doing nothing.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void remove() {
        throw new UnsupportedOperationException("remove");
    }
}
/**
 * Replace {@link Field} symbols with {@link Input} symbols found in the context.
 * This way one can use column identifiers for scalar testing while providing the (literal) inputs the
 * column should result in.
 */
private static class InputApplier extends SymbolVisitor<InputApplierContext, Input> {

    /** A literal already is an Input and is returned unchanged. */
    @Override
    public Input visitLiteral(Literal symbol, InputApplierContext context) {
        return symbol;
    }

    /**
     * Consumes the next input from the context for the given field.
     *
     * @throws IllegalArgumentException if the context has no input left, i.e. the expression
     *                                  references more fields than inputs were supplied
     */
    @Override
    public Input visitField(Field field, InputApplierContext context) {
        if (!context.hasNext()) {
            // fail here with a clear message instead of returning null and causing an
            // unexplained NullPointerException once the input is evaluated
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                "No input left for field <%s>: the expression references more fields than inputs were provided",
                field));
        }
        return context.next();
    }

    /**
     * Resolves all arguments of a nested function to Inputs and returns an Input for the
     * function itself: the normalized symbol if normalization yields an Input, otherwise a
     * FunctionExpression that evaluates the scalar on the resolved argument inputs.
     */
    @Override
    public Input visitFunction(Function function, InputApplierContext context) {
        Input[] argInputs = new Input[function.arguments().size()];
        for (int j = 0; j < function.arguments().size(); j++) {
            Input input = function.arguments().get(j).accept(this, context);
            argInputs[j] = input;
            // replace arguments on function for normalization
            if (input instanceof Literal) {
                function.arguments().set(j, (Literal) argInputs[j]);
            }
        }
        try {
            return (Input) context.sqlExpressions.normalize(function);
        } catch (Exception e) {
            // normalization failed or did not produce an Input (the cast failed):
            // fall back to evaluating the scalar on the resolved argument inputs
            FunctionIdent ident = function.info().ident();
            Scalar scalar =
                (Scalar) context.sqlExpressions.functions().getBuiltin(ident.name(), ident.argumentTypes());
            return new FunctionExpression<>(scalar, argInputs);
        }
    }
}
}