/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.block.BlockEncodingManager;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.RunLengthEncodedBlock;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.type.TypeRegistry;
import com.google.common.collect.Lists;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.List;

import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;

/**
 * Base class for aggregation function tests.
 *
 * <p>A subclass names the function under test ({@link #getFunctionName()} and
 * {@link #getFunctionParameterTypes()}), supplies input blocks
 * ({@link #getSequenceBlocks(int, int)}) and the matching expected result
 * ({@link #getExpectedValue(int, int)}). The {@code @Test} methods declared here then
 * run the function over empty, single-row, multi-row, all-null and mixed-null inputs.
 */
public abstract class AbstractTestAggregationFunction
{
    protected TypeRegistry typeRegistry;
    protected FunctionRegistry functionRegistry;

    /**
     * Creates a fresh {@link TypeRegistry} and {@link FunctionRegistry} before the tests
     * of the concrete class run.
     */
    @BeforeClass
    public final void initTestAggregationFunction()
    {
        typeRegistry = new TypeRegistry();
        functionRegistry = new FunctionRegistry(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig());
    }

    /**
     * Drops the references to the registries once the tests of the concrete class are done.
     */
    @AfterClass(alwaysRun = true)
    public final void destroyTestAggregationFunction()
    {
        functionRegistry = null;
        typeRegistry = null;
    }

    /**
     * Supplies the input blocks for one test run.
     *
     * <p>Block {@code i} is paired with parameter type {@code i} of the function (see
     * {@link #createAlternatingNullsBlock(List, Block...)}). The exact contents are up to
     * the subclass; presumably each block holds {@code length} sequential values beginning
     * at {@code start}.
     *
     * @param start first value of the sequence
     * @param length number of positions requested
     * @return the input blocks to feed into the aggregation
     */
    public abstract Block[] getSequenceBlocks(int start, int length);

    /**
     * Resolves the aggregation under test from {@link #getFunctionName()} and
     * {@link #getFunctionParameterTypes()} using {@link #functionRegistry}.
     *
     * @return the implementation of the resolved aggregation function
     */
    protected final InternalAggregationFunction getFunction()
    {
        List<TypeSignatureProvider> parameterTypes = fromTypeSignatures(Lists.transform(getFunctionParameterTypes(), TypeSignature::parseTypeSignature));
        Signature signature = functionRegistry.resolveFunction(QualifiedName.of(getFunctionName()), parameterTypes);
        return functionRegistry.getAggregateFunctionImplementation(signature);
    }

    /**
     * @return the name under which the aggregation function is looked up
     */
    protected abstract String getFunctionName();

    /**
     * @return the parameter types of the aggregation function, each in a form accepted by
     *         {@link TypeSignature#parseTypeSignature}
     */
    protected abstract List<String> getFunctionParameterTypes();

    /**
     * Computes the result the aggregation is expected to produce for the blocks returned by
     * {@link #getSequenceBlocks(int, int)} with the same arguments.
     *
     * @param start first value of the sequence
     * @param length number of positions in the sequence
     * @return the expected aggregation result
     */
    public abstract Object getExpectedValue(int start, int length);

    /**
     * Computes the expected result when the input also contains null positions.
     *
     * <p>This default implementation ignores {@code lengthIncludingNulls} and delegates to
     * {@link #getExpectedValue(int, int)}; subclasses may override it when the null positions
     * affect the result.
     *
     * @param start first value of the sequence
     * @param length number of non-null positions
     * @param lengthIncludingNulls total number of positions, null ones included
     * @return the expected aggregation result
     */
    public Object getExpectedValueIncludingNulls(int start, int length, int lengthIncludingNulls)
    {
        return getExpectedValue(start, length);
    }

    @Test
    public void testNoPositions()
    {
        testAggregation(getExpectedValue(0, 0), getSequenceBlocks(0, 0));
    }

    @Test
    public void testSinglePosition()
    {
        testAggregation(getExpectedValue(0, 1), getSequenceBlocks(0, 1));
    }

    @Test
    public void testMultiplePositions()
    {
        testAggregation(getExpectedValue(0, 5), getSequenceBlocks(0, 5));
    }

    /**
     * Feeds the aggregation ten positions that are null for every parameter.
     */
    @Test
    public void testAllPositionsNull()
            throws Exception
    {
        // if there are no parameters skip this test
        List<Type> parameterTypes = getFunction().getParameterTypes();
        if (parameterTypes.isEmpty()) {
            return;
        }

        Block[] blocks = new Block[parameterTypes.size()];
        for (int i = 0; i < parameterTypes.size(); i++) {
            // Build the null block with the i-th parameter's own type, so that every argument
            // of a multi-parameter function receives a block of the matching kind.
            Block nullValueBlock = parameterTypes.get(i).createBlockBuilder(new BlockBuilderStatus(), 1)
                    .appendNull()
                    .build();
            blocks[i] = new RunLengthEncodedBlock(nullValueBlock, 10);
        }

        testAggregation(getExpectedValueIncludingNulls(0, 0, 10), blocks);
    }

    /**
     * Feeds the aggregation ten sequence values interleaved with ten nulls.
     */
    @Test
    public void testMixedNullAndNonNullPositions()
    {
        // if there are no parameters skip this test
        List<Type> parameterTypes = getFunction().getParameterTypes();
        if (parameterTypes.isEmpty()) {
            return;
        }

        Block[] alternatingNullsBlocks = createAlternatingNullsBlock(parameterTypes, getSequenceBlocks(0, 10));
        testAggregation(getExpectedValueIncludingNulls(0, 10, 20), alternatingNullsBlocks);
    }

    @Test
    public void testNegativeOnlyValues()
    {
        testAggregation(getExpectedValue(-10, 5), getSequenceBlocks(-10, 5));
    }

    @Test
    public void testPositiveOnlyValues()
    {
        testAggregation(getExpectedValue(2, 4), getSequenceBlocks(2, 4));
    }

    /**
     * Builds, for each input block, a new block in which every original position is preceded
     * by a null, so the result has twice as many positions as the input.
     *
     * @param types type of each block; {@code types.get(i)} is used for {@code sequenceBlocks[i]}
     * @param sequenceBlocks the blocks whose values are interleaved with nulls
     * @return one null-interleaved block per input block, in the same order
     */
    public Block[] createAlternatingNullsBlock(List<Type> types, Block... sequenceBlocks)
    {
        Block[] alternatingNullsBlocks = new Block[sequenceBlocks.length];
        for (int i = 0; i < sequenceBlocks.length; i++) {
            int positionCount = sequenceBlocks[i].getPositionCount();
            Type type = types.get(i);
            BlockBuilder blockBuilder = type.createBlockBuilder(new BlockBuilderStatus(), positionCount);
            for (int position = 0; position < positionCount; position++) {
                // append null
                blockBuilder.appendNull();
                // append value
                type.appendTo(sequenceBlocks[i], position, blockBuilder);
            }
            alternatingNullsBlocks[i] = blockBuilder.build();
        }
        return alternatingNullsBlocks;
    }

    /**
     * Runs the function under test over {@code blocks} and asserts on the result via
     * {@code AggregationTestUtils.assertAggregation}.
     *
     * @param expectedValue the result the aggregation should produce
     * @param blocks the input blocks, one per function parameter
     */
    protected void testAggregation(Object expectedValue, Block... blocks)
    {
        assertAggregation(getFunction(), expectedValue, blocks);
    }
}