/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionListBuilder;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.SqlFunction;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.OperatorType;
import com.facebook.presto.spi.type.DecimalParseResult;
import com.facebook.presto.spi.type.Decimals;
import com.facebook.presto.spi.type.SqlDecimal;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;

import java.math.BigInteger;
import java.util.List;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.metadata.FunctionRegistry.mangleOperatorName;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static com.facebook.presto.spi.type.DecimalType.createDecimalType;
import static com.facebook.presto.type.UnknownType.UNKNOWN;
import static io.airlift.testing.Closeables.closeAllRuntimeException;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;

/**
 * Base class for tests that evaluate SQL projections through a {@link FunctionAssertions}
 * instance. The instance is created before the test class runs and closed afterwards.
 */
public abstract class AbstractTestFunctions
{
    private final Session session;
    private final FeaturesConfig config;
    protected FunctionAssertions functionAssertions;

    /**
     * Uses {@code TEST_SESSION} and a default {@link FeaturesConfig}.
     */
    protected AbstractTestFunctions()
    {
        this(TEST_SESSION);
    }

    /**
     * Uses the given session and a default {@link FeaturesConfig}.
     *
     * @param session session handed to {@link FunctionAssertions}; must not be null
     */
    protected AbstractTestFunctions(Session session)
    {
        this(session, new FeaturesConfig());
    }

    /**
     * Uses {@code TEST_SESSION} and the given configuration.
     *
     * @param config configuration handed to {@link FunctionAssertions}; must not be null
     */
    protected AbstractTestFunctions(FeaturesConfig config)
    {
        this(TEST_SESSION, config);
    }

    /**
     * @param session session handed to {@link FunctionAssertions}; must not be null
     * @param config configuration handed to {@link FunctionAssertions}; must not be null
     * @throws NullPointerException if either argument is null
     */
    protected AbstractTestFunctions(Session session, FeaturesConfig config)
    {
        this.session = requireNonNull(session, "session is null");
        this.config = requireNonNull(config, "config is null");
    }

    /**
     * Creates the {@link FunctionAssertions} used by all assertion helpers of this class.
     */
    @BeforeClass
    public final void initTestFunctions()
    {
        functionAssertions = new FunctionAssertions(session, config);
    }

    /**
     * Closes the {@link FunctionAssertions} and drops the reference to it.
     */
    @AfterClass(alwaysRun = true)
    public final void destroyTestFunctions()
    {
        try {
            closeAllRuntimeException(functionAssertions);
        }
        finally {
            // Drop the reference even when closing fails
            functionAssertions = null;
        }
    }

    /**
     * Asserts that {@code projection} evaluates to {@code expected} with type {@code expectedType}.
     */
    protected void assertFunction(String projection, Type expectedType, Object expected)
    {
        functionAssertions.assertFunction(projection, expectedType, expected);
    }

    /**
     * Asserts the result of applying {@code operator} to {@code value}. The operator is invoked
     * as a function call using its mangled name in double quotes.
     */
    protected void assertOperator(OperatorType operator, String value, Type expectedType, Object expected)
    {
        functionAssertions.assertFunction(format("\"%s\"(%s)", mangleOperatorName(operator), value), expectedType, expected);
    }

    /**
     * Asserts that {@code statement} evaluates to {@code expectedResult}; the expected type is
     * the decimal type with the precision and scale of {@code expectedResult}.
     */
    protected void assertDecimalFunction(String statement, SqlDecimal expectedResult)
    {
        assertFunction(statement, createDecimalType(expectedResult.getPrecision(), expectedResult.getScale()), expectedResult);
    }

    /**
     * Asserts that evaluating {@code projection} throws some {@link RuntimeException}.
     */
    protected void assertInvalidFunction(String projection)
    {
        try {
            evaluateInvalid(projection);
            fail("Expected to fail");
        }
        catch (RuntimeException expected) {
            // Expected
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with error code
     * {@code INVALID_FUNCTION_ARGUMENT} and exactly the given message.
     */
    protected void assertInvalidFunction(String projection, String message)
    {
        try {
            evaluateInvalid(projection);
            fail("Expected to throw an INVALID_FUNCTION_ARGUMENT exception with message " + message);
        }
        catch (PrestoException e) {
            verifyPrestoException(e, INVALID_FUNCTION_ARGUMENT, message);
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link SemanticException} with the
     * given code.
     */
    protected void assertInvalidFunction(String projection, SemanticErrorCode expectedErrorCode)
    {
        try {
            evaluateInvalid(projection);
            fail(format("Expected to throw %s exception", expectedErrorCode));
        }
        catch (SemanticException e) {
            verifySemanticException(e, expectedErrorCode);
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link SemanticException} with the
     * given code and exactly the given message.
     */
    protected void assertInvalidFunction(String projection, SemanticErrorCode expectedErrorCode, String message)
    {
        try {
            evaluateInvalid(projection);
            fail(format("Expected to throw %s exception", expectedErrorCode));
        }
        catch (SemanticException e) {
            verifySemanticException(e, expectedErrorCode, message);
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with the given
     * error code.
     */
    protected void assertInvalidFunction(String projection, ErrorCodeSupplier expectedErrorCode)
    {
        try {
            evaluateInvalid(projection);
            fail(format("Expected to throw %s exception", expectedErrorCode.toErrorCode()));
        }
        catch (PrestoException e) {
            verifyPrestoException(e, expectedErrorCode);
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with error code
     * {@code NUMERIC_VALUE_OUT_OF_RANGE} and exactly the given message.
     */
    protected void assertNumericOverflow(String projection, String message)
    {
        try {
            evaluateInvalid(projection);
            fail("Expected to throw an NUMERIC_VALUE_OUT_OF_RANGE exception with message " + message);
        }
        catch (PrestoException e) {
            verifyPrestoException(e, NUMERIC_VALUE_OUT_OF_RANGE, message);
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with error code
     * {@code INVALID_CAST_ARGUMENT}.
     */
    protected void assertInvalidCast(String projection)
    {
        try {
            evaluateInvalid(projection);
            fail("Expected to throw an INVALID_CAST_ARGUMENT exception");
        }
        catch (PrestoException e) {
            verifyPrestoException(e, INVALID_CAST_ARGUMENT);
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with error code
     * {@code INVALID_CAST_ARGUMENT} and exactly the given message.
     */
    protected void assertInvalidCast(String projection, String message)
    {
        try {
            evaluateInvalid(projection);
            fail("Expected to throw an INVALID_CAST_ARGUMENT exception");
        }
        catch (PrestoException e) {
            verifyPrestoException(e, INVALID_CAST_ARGUMENT, message);
        }
    }

    /**
     * Adds a single scalar function to the function registry used by {@link #functionAssertions}.
     */
    protected void registerScalarFunction(SqlScalarFunction sqlScalarFunction)
    {
        Metadata metadata = functionAssertions.getMetadata();
        metadata.getFunctionRegistry().addFunctions(ImmutableList.of(sqlScalarFunction));
    }

    /**
     * Registers the functions that {@code FunctionListBuilder.scalars} derives from {@code clazz}.
     */
    protected void registerScalar(Class<?> clazz)
    {
        Metadata metadata = functionAssertions.getMetadata();
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalars(clazz)
                .getFunctions();
        metadata.getFunctionRegistry().addFunctions(functions);
    }

    /**
     * Registers the functions that {@code FunctionListBuilder.scalar} derives from {@code clazz}.
     */
    protected void registerParametricScalar(Class<?> clazz)
    {
        Metadata metadata = functionAssertions.getMetadata();
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalar(clazz)
                .getFunctions();
        metadata.getFunctionRegistry().addFunctions(functions);
    }

    /**
     * Parses {@code decimalString} with {@code Decimals.parseIncludeLeadingZerosInPrecision} and
     * returns a {@link SqlDecimal} carrying the parsed unscaled value, precision and scale.
     */
    protected static SqlDecimal decimal(String decimalString)
    {
        DecimalParseResult parseResult = Decimals.parseIncludeLeadingZerosInPrecision(decimalString);
        BigInteger unscaledValue;
        if (parseResult.getType().isShort()) {
            // Short decimals are delivered as a Long holding the unscaled value
            unscaledValue = BigInteger.valueOf((Long) parseResult.getObject());
        }
        else {
            // Other decimals are delivered as a Slice that has to be decoded
            unscaledValue = Decimals.decodeUnscaledValue((Slice) parseResult.getObject());
        }
        return new SqlDecimal(unscaledValue, parseResult.getType().getPrecision(), parseResult.getType().getScale());
    }

    /**
     * Returns {@link #decimal(String)} of {@code value} zero-padded to {@code Decimals.MAX_PRECISION}
     * digits.
     */
    protected static SqlDecimal maxPrecisionDecimal(long value)
    {
        // The format width counts the minus sign, so negative values get one extra character
        int width = Decimals.MAX_PRECISION + (value < 0 ? 1 : 0);
        return decimal(format("%0" + width + "d", value));
    }

    private void evaluateInvalid(String projection)
    {
        // type isn't necessary as the function is not valid
        functionAssertions.assertFunction(projection, UNKNOWN, null);
    }

    // Checks the error code of a caught PrestoException, keeping the exception attached on mismatch
    private static void verifyPrestoException(PrestoException actual, ErrorCodeSupplier expectedErrorCode)
    {
        try {
            assertEquals(actual.getErrorCode(), expectedErrorCode.toErrorCode());
        }
        catch (AssertionError failure) {
            throw withThrown(failure, actual);
        }
    }

    // Checks error code, then message, of a caught PrestoException, keeping the exception attached on mismatch
    private static void verifyPrestoException(PrestoException actual, ErrorCodeSupplier expectedErrorCode, String expectedMessage)
    {
        try {
            assertEquals(actual.getErrorCode(), expectedErrorCode.toErrorCode());
            assertEquals(actual.getMessage(), expectedMessage);
        }
        catch (AssertionError failure) {
            throw withThrown(failure, actual);
        }
    }

    // Checks the code of a caught SemanticException, keeping the exception attached on mismatch
    private static void verifySemanticException(SemanticException actual, SemanticErrorCode expectedErrorCode)
    {
        try {
            assertEquals(actual.getCode(), expectedErrorCode);
        }
        catch (AssertionError failure) {
            throw withThrown(failure, actual);
        }
    }

    // Checks code, then message, of a caught SemanticException, keeping the exception attached on mismatch
    private static void verifySemanticException(SemanticException actual, SemanticErrorCode expectedErrorCode, String expectedMessage)
    {
        try {
            assertEquals(actual.getCode(), expectedErrorCode);
            assertEquals(actual.getMessage(), expectedMessage);
        }
        catch (AssertionError failure) {
            throw withThrown(failure, actual);
        }
    }

    // Records the exception the projection actually threw so its stack trace shows up in the test report
    private static AssertionError withThrown(AssertionError failure, Throwable thrown)
    {
        failure.addSuppressed(thrown);
        return failure;
    }
}