/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.SqlDecimal;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.type.RowType;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock;
import static com.facebook.presto.block.BlockAssertions.createDoublesBlock;
import static com.facebook.presto.block.BlockAssertions.createLongDecimalsBlock;
import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
import static com.facebook.presto.block.BlockAssertions.createShortDecimalsBlock;
import static com.facebook.presto.block.BlockAssertions.createStringsBlock;
import static com.facebook.presto.metadata.FunctionKind.AGGREGATE;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Arrays.asList;
import static org.testng.Assert.assertNotNull;
/**
 * Tests for the {@code min_by} and {@code max_by} aggregation functions.
 * <p>
 * Each test resolves an implementation from the function registry for one
 * (value type, key type) pair and checks the aggregated result over small hand-built blocks.
 * In every {@code assertAggregation} call the first block holds the values that may be
 * returned and the second block holds the keys that are compared.
 */
public class TestMinMaxByAggregation
{
    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();

    /**
     * Verifies that {@code min_by} and {@code max_by} resolve for every combination of
     * an orderable key type and a value type whose Java type is supported by {@link StateCompiler}.
     */
    @Test
    public void testAllRegistered()
    {
        // getTypes() builds a fresh list on each call, so compute it once instead of once per key type.
        List<Type> types = getTypes();
        Set<Type> orderableTypes = types.stream()
                .filter(Type::isOrderable)
                .collect(toImmutableSet());

        for (Type keyType : orderableTypes) {
            for (Type valueType : types) {
                if (!StateCompiler.getSupportedFieldTypes().contains(valueType.getJavaType())) {
                    continue;
                }
                for (String name : ImmutableList.of("min_by", "max_by")) {
                    Signature signature = new Signature(name, AGGREGATE, valueType.getTypeSignature(), valueType.getTypeSignature(), keyType.getTypeSignature());
                    // The message identifies the failing combination; without it a failure is anonymous.
                    assertNotNull(
                            METADATA.getFunctionRegistry().getAggregateFunctionImplementation(signature),
                            name + " is not registered for value type " + valueType + " and key type " + keyType);
                }
            }
        }
    }

    /**
     * Returns the types exercised by {@link #testAllRegistered()}: everything reported by the
     * type manager plus explicitly constructed VARCHAR, decimal and row types.
     */
    private static List<Type> getTypes()
    {
        List<Type> simpleTypes = METADATA.getTypeManager().getTypes();
        return new ImmutableList.Builder<Type>()
                .addAll(simpleTypes)
                .add(VARCHAR)
                .add(DecimalType.createDecimalType(1))
                .add(new RowType(ImmutableList.of(BIGINT, VARCHAR, DOUBLE), Optional.empty()))
                .build();
    }

    /**
     * {@code min_by} over doubles when some values or keys are null.
     */
    @Test
    public void testMinNull()
    {
        InternalAggregationFunction function = minBy(StandardTypes.DOUBLE, StandardTypes.DOUBLE);
        // The null value is paired with the larger key, so it is not selected.
        assertAggregation(
                function,
                1.0,
                createDoublesBlock(1.0, null),
                createDoublesBlock(1.0, 2.0));
        // Rows whose key is null do not take part in the comparison.
        assertAggregation(
                function,
                10.0,
                createDoublesBlock(10.0, 9.0, 8.0, 11.0),
                createDoublesBlock(1.0, null, 2.0, null));
    }

    /**
     * {@code max_by} over doubles when some values or keys are null.
     */
    @Test
    public void testMaxNull()
    {
        InternalAggregationFunction function = maxBy(StandardTypes.DOUBLE, StandardTypes.DOUBLE);
        // The null value is paired with the largest key, so null is the expected result.
        assertAggregation(
                function,
                null,
                createDoublesBlock(1.0, null),
                createDoublesBlock(1.0, 2.0));
        // Rows whose key is null do not take part in the comparison.
        assertAggregation(
                function,
                10.0,
                createDoublesBlock(8.0, 9.0, 10.0, 11.0),
                createDoublesBlock(-2.0, null, -1.0, null));
    }

    /**
     * {@code min_by} with double values and double keys.
     */
    @Test
    public void testMinDoubleDouble()
    {
        InternalAggregationFunction function = minBy(StandardTypes.DOUBLE, StandardTypes.DOUBLE);
        // All-null input yields null.
        assertAggregation(
                function,
                null,
                createDoublesBlock(null, null),
                createDoublesBlock(null, null));

        assertAggregation(
                function,
                3.0,
                createDoublesBlock(3.0, 2.0, 5.0, 3.0),
                createDoublesBlock(1.0, 1.5, 2.0, 4.0));
    }

    /**
     * {@code max_by} with double values and double keys.
     */
    @Test
    public void testMaxDoubleDouble()
    {
        InternalAggregationFunction function = maxBy(StandardTypes.DOUBLE, StandardTypes.DOUBLE);
        // All-null input yields null.
        assertAggregation(
                function,
                null,
                createDoublesBlock(null, null),
                createDoublesBlock(null, null));

        assertAggregation(
                function,
                2.0,
                createDoublesBlock(3.0, 2.0, null),
                createDoublesBlock(1.0, 1.5, null));
    }

    /**
     * {@code min_by} with varchar values and double keys.
     */
    @Test
    public void testMinDoubleVarchar()
    {
        InternalAggregationFunction function = minBy(StandardTypes.VARCHAR, StandardTypes.DOUBLE);
        assertAggregation(
                function,
                "z",
                createStringsBlock("z", "a", "x", "b"),
                createDoublesBlock(1.0, 2.0, 2.0, 3.0));

        assertAggregation(
                function,
                "a",
                createStringsBlock("zz", "hi", "bb", "a"),
                createDoublesBlock(0.0, 1.0, 2.0, -1.0));
    }

    /**
     * {@code max_by} with varchar values and double keys, including fully null rows.
     */
    @Test
    public void testMaxDoubleVarchar()
    {
        InternalAggregationFunction function = maxBy(StandardTypes.VARCHAR, StandardTypes.DOUBLE);
        assertAggregation(
                function,
                "a",
                createStringsBlock("z", "a", null),
                createDoublesBlock(1.0, 2.0, null));

        assertAggregation(
                function,
                "hi",
                createStringsBlock("zz", "hi", null, "a"),
                createDoublesBlock(0.0, 1.0, null, -1.0));
    }

    /**
     * {@code min_by} with array(bigint) values and bigint keys.
     */
    @Test
    public void testMinLongLongArray()
    {
        InternalAggregationFunction function = minBy("array(bigint)", StandardTypes.BIGINT);
        assertAggregation(
                function,
                ImmutableList.of(8L, 9L),
                createArrayBigintBlock(ImmutableList.of(ImmutableList.of(8L, 9L), ImmutableList.of(1L, 2L), ImmutableList.of(6L, 7L), ImmutableList.of(2L, 3L))),
                createLongsBlock(1L, 2L, 2L, 3L));

        assertAggregation(
                function,
                ImmutableList.of(2L),
                createArrayBigintBlock(ImmutableList.of(ImmutableList.of(8L, 9L), ImmutableList.of(6L, 7L), ImmutableList.of(2L, 3L), ImmutableList.of(2L))),
                createLongsBlock(0L, 1L, 2L, -1L));
    }

    /**
     * {@code min_by} with bigint values and array(bigint) keys.
     */
    @Test
    public void testMinLongArrayLong()
    {
        InternalAggregationFunction function = minBy(StandardTypes.BIGINT, "array(bigint)");
        assertAggregation(
                function,
                3L,
                createLongsBlock(1L, 2L, 2L, 3L),
                createArrayBigintBlock(ImmutableList.of(ImmutableList.of(8L, 9L), ImmutableList.of(1L, 2L), ImmutableList.of(6L, 7L), ImmutableList.of(1L, 1L))));

        // The expected value implies that the key [-1] orders before [-1, -3].
        assertAggregation(
                function,
                -1L,
                createLongsBlock(0L, 1L, 2L, -1L),
                createArrayBigintBlock(ImmutableList.of(ImmutableList.of(8L, 9L), ImmutableList.of(6L, 7L), ImmutableList.of(-1L, -3L), ImmutableList.of(-1L))));
    }

    /**
     * {@code max_by} with bigint values and array(bigint) keys.
     */
    @Test
    public void testMaxLongArrayLong()
    {
        InternalAggregationFunction function = maxBy(StandardTypes.BIGINT, "array(bigint)");
        assertAggregation(
                function,
                1L,
                createLongsBlock(1L, 2L, 2L, 3L),
                createArrayBigintBlock(ImmutableList.of(ImmutableList.of(8L, 9L), ImmutableList.of(1L, 2L), ImmutableList.of(6L, 7L), ImmutableList.of(1L, 1L))));

        // The expected value implies that the key [-1, -3] orders after [-1].
        assertAggregation(
                function,
                2L,
                createLongsBlock(0L, 1L, 2L, -1L),
                createArrayBigintBlock(ImmutableList.of(ImmutableList.of(-8L, 9L), ImmutableList.of(-6L, 7L), ImmutableList.of(-1L, -3L), ImmutableList.of(-1L))));
    }

    /**
     * {@code max_by} with array(bigint) values and bigint keys, including fully null rows.
     */
    @Test
    public void testMaxLongLongArray()
    {
        InternalAggregationFunction function = maxBy("array(bigint)", StandardTypes.BIGINT);
        // asList is used instead of ImmutableList.of because the input contains null elements.
        assertAggregation(
                function,
                ImmutableList.of(1L, 2L),
                createArrayBigintBlock(asList(asList(3L, 4L), asList(1L, 2L), null)),
                createLongsBlock(1L, 2L, null));

        assertAggregation(
                function,
                ImmutableList.of(2L, 3L),
                createArrayBigintBlock(asList(asList(3L, 4L), asList(2L, 3L), null, asList(1L, 2L))),
                createLongsBlock(0L, 1L, null, -1L));
    }

    /**
     * {@code min_by} with decimal(19,1) values and keys, built with the long-decimal block helper.
     */
    @Test
    public void testMinLongDecimalDecimal()
    {
        InternalAggregationFunction function = minBy("decimal(19,1)", "decimal(19,1)");
        assertAggregation(
                function,
                SqlDecimal.of("2.2"),
                createLongDecimalsBlock("1.1", "2.2", "3.3"),
                createLongDecimalsBlock("1.2", "1.0", "2.0"));
    }

    /**
     * {@code max_by} with decimal(19,1) values and keys, built with the long-decimal block helper.
     */
    @Test
    public void testMaxLongDecimalDecimal()
    {
        InternalAggregationFunction function = maxBy("decimal(19,1)", "decimal(19,1)");
        assertAggregation(
                function,
                SqlDecimal.of("3.3"),
                createLongDecimalsBlock("1.1", "2.2", "3.3", "4.4"),
                createLongDecimalsBlock("1.2", "1.0", "2.0", "1.5"));
    }

    /**
     * {@code min_by} with decimal(10,1) values and keys, built with the short-decimal block helper.
     */
    @Test
    public void testMinShortDecimalDecimal()
    {
        InternalAggregationFunction function = minBy("decimal(10,1)", "decimal(10,1)");
        assertAggregation(
                function,
                SqlDecimal.of("2.2"),
                createShortDecimalsBlock("1.1", "2.2", "3.3"),
                createShortDecimalsBlock("1.2", "1.0", "2.0"));
    }

    /**
     * {@code max_by} with decimal(10,1) values and keys, built with the short-decimal block helper.
     */
    @Test
    public void testMaxShortDecimalDecimal()
    {
        InternalAggregationFunction function = maxBy("decimal(10,1)", "decimal(10,1)");
        assertAggregation(
                function,
                SqlDecimal.of("3.3"),
                createShortDecimalsBlock("1.1", "2.2", "3.3", "4.4"),
                createShortDecimalsBlock("1.2", "1.0", "2.0", "1.5"));
    }

    /**
     * Resolves the {@code min_by} implementation for the given value and key type signatures.
     */
    private static InternalAggregationFunction minBy(String valueType, String keyType)
    {
        return getAggregation("min_by", valueType, keyType);
    }

    /**
     * Resolves the {@code max_by} implementation for the given value and key type signatures.
     */
    private static InternalAggregationFunction maxBy(String valueType, String keyType)
    {
        return getAggregation("max_by", valueType, keyType);
    }

    /**
     * Looks up an aggregate function whose return type and first argument are {@code valueType}
     * and whose second argument is {@code keyType}.
     *
     * @param name function name, e.g. {@code "min_by"}
     * @param valueType type signature string of the returned value and first argument
     * @param keyType type signature string of the second (compared) argument
     */
    private static InternalAggregationFunction getAggregation(String name, String valueType, String keyType)
    {
        return METADATA.getFunctionRegistry().getAggregateFunctionImplementation(
                new Signature(name, AGGREGATE, parseTypeSignature(valueType), parseTypeSignature(valueType), parseTypeSignature(keyType)));
    }
}