/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.exec.vector;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.reflect.Constructor;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.CompilationOpContext;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.vector.util.FakeCaptureOutputOperator;
import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromConcat;
import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromLongIterables;
import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromObjectIterables;
import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromRepeats;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.GroupByDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc;
import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc.ProcessingMode;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.io.TimestampWritable;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BooleanWritable;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
/**
* Unit test for the vectorized GROUP BY operator.
*/
public class TestVectorGroupByOperator {
HiveConf hconf = new HiveConf();
/**
 * Builds a column-reference expression for the given column name and type.
 * The {@code ctx} parameter is unused but kept for signature symmetry with
 * the other build* helpers in this test.
 */
private static ExprNodeDesc buildColumnDesc(
    VectorizationContext ctx,
    String column,
    TypeInfo typeInfo) {
  // "table" is a placeholder table alias; the column is not a partition/virtual column.
  ExprNodeColumnDesc columnRef = new ExprNodeColumnDesc(typeInfo, column, "table", false);
  return columnRef;
}
/**
 * Builds an AggregationDesc for a single-parameter UDAF (e.g. "sum", "min")
 * applied to the given column in the given evaluator mode.
 */
private static AggregationDesc buildAggregationDesc(
    VectorizationContext ctx,
    String aggregate,
    GenericUDAFEvaluator.Mode mode,
    String column,
    TypeInfo typeInfo) {
  AggregationDesc aggDesc = new AggregationDesc();
  aggDesc.setGenericUDAFName(aggregate);
  aggDesc.setMode(mode);
  // The single UDAF parameter is a reference to the aggregated column.
  ArrayList<ExprNodeDesc> parameters = new ArrayList<ExprNodeDesc>();
  parameters.add(buildColumnDesc(ctx, column, typeInfo));
  aggDesc.setParameters(parameters);
  return aggDesc;
}
/**
 * Builds an AggregationDesc for COUNT(*) in PARTIAL1 mode: a COUNT aggregate
 * with an empty parameter list. {@code ctx} is unused but kept for signature
 * symmetry with the other build* helpers.
 */
private static AggregationDesc buildAggregationDescCountStar(
    VectorizationContext ctx) {
  AggregationDesc countStar = new AggregationDesc();
  countStar.setGenericUDAFName("COUNT");
  countStar.setMode(GenericUDAFEvaluator.Mode.PARTIAL1);
  // COUNT(*) takes no parameters — empty list, not null.
  countStar.setParameters(new ArrayList<ExprNodeDesc>());
  return countStar;
}
/**
 * Builds a GroupByDesc with a single aggregate over {@code column} and no
 * grouping keys, running the vectorized operator in GLOBAL processing mode.
 * The single output column is named "_col0".
 */
private static GroupByDesc buildGroupByDescType(
    VectorizationContext ctx,
    String aggregate,
    GenericUDAFEvaluator.Mode mode,
    String column,
    TypeInfo dataType) {
  ArrayList<AggregationDesc> aggregators = new ArrayList<AggregationDesc>();
  aggregators.add(buildAggregationDesc(ctx, aggregate, mode, column, dataType));
  ArrayList<String> outputNames = new ArrayList<String>();
  outputNames.add("_col0");
  GroupByDesc groupBy = new GroupByDesc();
  groupBy.setVectorDesc(new VectorGroupByDesc());
  groupBy.setAggregators(aggregators);
  groupBy.setOutputColumnNames(outputNames);
  // No keys: the whole input is one global group.
  ((VectorGroupByDesc) groupBy.getVectorDesc()).setProcessingMode(ProcessingMode.GLOBAL);
  return groupBy;
}
/**
 * Builds a GroupByDesc for COUNT(*) with no grouping keys; output column "_col0".
 * NOTE(review): unlike buildGroupByDescType this does not set a processing mode
 * on the VectorGroupByDesc — it relies on the VectorGroupByDesc default. Verify
 * that is intentional.
 */
private static GroupByDesc buildGroupByDescCountStar(
    VectorizationContext ctx) {
  ArrayList<AggregationDesc> aggregators = new ArrayList<AggregationDesc>();
  aggregators.add(buildAggregationDescCountStar(ctx));
  ArrayList<String> outputNames = new ArrayList<String>();
  outputNames.add("_col0");
  GroupByDesc countStarDesc = new GroupByDesc();
  countStarDesc.setVectorDesc(new VectorGroupByDesc());
  countStarDesc.setOutputColumnNames(outputNames);
  countStarDesc.setAggregators(aggregators);
  return countStarDesc;
}
/**
 * Builds a GroupByDesc with one aggregate (PARTIAL1 mode) over {@code column}
 * and one grouping key {@code key}. The operator runs in HASH processing mode,
 * and the key contributes a second output column "_col1".
 */
private static GroupByDesc buildKeyGroupByDesc(
    VectorizationContext ctx,
    String aggregate,
    String column,
    TypeInfo dataTypeInfo,
    String key,
    TypeInfo keyTypeInfo) {
  GroupByDesc desc = buildGroupByDescType(
      ctx, aggregate, GenericUDAFEvaluator.Mode.PARTIAL1, column, dataTypeInfo);
  // Keyed aggregation uses HASH mode rather than the GLOBAL mode set by buildGroupByDescType.
  ((VectorGroupByDesc) desc.getVectorDesc()).setProcessingMode(ProcessingMode.HASH);
  ArrayList<ExprNodeDesc> keys = new ArrayList<ExprNodeDesc>();
  keys.add(buildColumnDesc(ctx, key, keyTypeInfo));
  desc.setKeys(keys);
  desc.getOutputColumnNames().add("_col1");
  return desc;
}
long outputRowCount = 0;
@Test
public void testMemoryPressureFlush() throws HiveException {
// Build a keyed (hash-mode) MAX aggregation over two long columns "Key"/"Value".
List<String> mapColumnNames = new ArrayList<String>();
mapColumnNames.add("Key");
mapColumnNames.add("Value");
VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
GroupByDesc desc = buildKeyGroupByDesc (ctx, "max",
"Value", TypeInfoFactory.longTypeInfo,
"Key", TypeInfoFactory.longTypeInfo);
// Set the memory threshold so that we get 100Kb before we need to flush.
MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
long maxMemory = memoryMXBean.getHeapMemoryUsage().getMax();
float treshold = 100.0f*1024.0f/maxMemory;
desc.setMemoryThreshold(treshold);
CompilationOpContext cCtx = new CompilationOpContext();
Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
// Vectorize the plan operator and attach a capture operator to observe output rows.
VectorGroupByOperator vgo =
(VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
vgo.initialize(hconf, null);
// Count every row emitted; a memory-pressure flush shows up here as emitted rows.
this.outputRowCount = 0;
out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() {
@Override
public void inspectRow(Object row, int tag) throws HiveException {
++outputRowCount;
}
});
// Endless source of distinct long values 1, 2, 3, ... — every row is a new
// key, so each processed batch grows the operator's hash table.
Iterable<Object> it = new Iterable<Object>() {
@Override
public Iterator<Object> iterator() {
return new Iterator<Object> () {
long value = 0;
@Override
public boolean hasNext() {
return true;
}
@Override
public Object next() {
return ++value;
}
@Override
public void remove() {
}
};
}
};
FakeVectorRowBatchFromObjectIterables data = new FakeVectorRowBatchFromObjectIterables(
100,
new String[] {"long", "long"},
it,
it);
// The 'it' data source will produce data w/o ever ending
// We want to see that memory pressure kicks in and some
// entries in the VGBY are flushed.
long countRowsProduced = 0;
for (VectorizedRowBatch unit: data) {
countRowsProduced += 100;
vgo.process(unit, 0);
// Stop as soon as the first flush emits anything downstream.
if (0 < outputRowCount) {
break;
}
// Set an upper bound how much we're willing to push before it should flush
// we've set the memory treshold at 100kb, each key is distinct
// It should not go beyond 100k/16 (key+data)
assertTrue(countRowsProduced < 100*1024/16);
}
// At least one flush must have happened before the row-count bound was hit.
assertTrue(0 < outputRowCount);
}
// SUM over a double column grouped by a three-part (int, string, int) key.
// Rows with null key components must group together with other rows having the
// same (nulls included) key triple:
//   (1, "A", 2)       -> 2.0 + 4.0         = 6.0
//   (null, "C", null) -> 8.0
//   (2, null, 2)      -> 16.0 + 32.0       = 48.0
//   (null, "A", null) -> 1.0 + 64.0        = 65.0
@Test
public void testMultiKeyIntStringInt() throws HiveException {
testMultiKey(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"int", "string", "int", "double"},
Arrays.asList(new Object[]{null, 1, 1, null, 2, 2, null}),
Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}),
Arrays.asList(new Object[]{null, 2, 2, null, 2, 2, null}),
Arrays.asList(new Object[]{1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0})),
buildHashMap(
Arrays.asList( 1, "A", 2), 6.0,
Arrays.asList(null, "C", null), 8.0,
Arrays.asList( 2, null, 2), 48.0,
Arrays.asList(null, "A", null), 65.0));
}
@Test
public void testMultiKeyStringByteString() throws HiveException {
testMultiKey(
"sum",
new FakeVectorRowBatchFromObjectIterables(
1,
new String[] {"string", "tinyint", "string", "double"},
Arrays.asList(new Object[]{"A", "A", null}),
Arrays.asList(new Object[]{ 1, 1, 1}),
Arrays.asList(new Object[]{ "A", "A", "A"}),
Arrays.asList(new Object[]{ 1.0, 1.0, 1.0})),
buildHashMap(
Arrays.asList( "A", (byte)1, "A"), 2.0,
Arrays.asList( null, (byte)1, "A"), 1.0));
}
@Test
public void testMultiKeyStringIntString() throws HiveException {
testMultiKey(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"string", "int", "string", "double"},
Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}),
Arrays.asList(new Object[]{null, 1, 1, null, 2, 2, null}),
Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}),
Arrays.asList(new Object[]{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})),
buildHashMap(
Arrays.asList(null, 2, null), 2.0,
Arrays.asList( "C", null, "C"), 1.0,
Arrays.asList( "A", 1, "A"), 2.0,
Arrays.asList( "A", null, "A"), 2.0));
}
@Test
public void testMultiKeyIntStringString() throws HiveException {
testMultiKey(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"int", "string", "string", "double"},
Arrays.asList(new Object[]{null, 1, 1, null, 2, 2, null}),
Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}),
Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}),
Arrays.asList(new Object[]{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})),
buildHashMap(
Arrays.asList( 2, null, null), 2.0,
Arrays.asList(null, "C", "C"), 1.0,
Arrays.asList( 1, "A", "A"), 2.0,
Arrays.asList(null, "A", "A"), 2.0));
}
@Test
public void testMultiKeyDoubleStringInt() throws HiveException {
testMultiKey(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"double", "string", "int", "double"},
Arrays.asList(new Object[]{null, 1.0, 1.0, null, 2.0, 2.0, null}),
Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}),
Arrays.asList(new Object[]{null, 2, 2, null, 2, 2, null}),
Arrays.asList(new Object[]{1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0})),
buildHashMap(
Arrays.asList( 1.0, "A", 2), 6.0,
Arrays.asList(null, "C", null), 8.0,
Arrays.asList( 2.0, null, 2), 48.0,
Arrays.asList(null, "A", null), 65.0));
}
@Test
public void testMultiKeyDoubleShortString() throws HiveException {
short s = 2;
testMultiKey(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"double", "smallint", "string", "double"},
Arrays.asList(new Object[]{null, 1.0, 1.0, null, 2.0, 2.0, null}),
Arrays.asList(new Object[]{null, s, s, null, s, s, null}),
Arrays.asList(new Object[]{ "A", "A", "A", "C", null, null, "A"}),
Arrays.asList(new Object[]{1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0})),
buildHashMap(
Arrays.asList( 1.0, s, "A"), 6.0,
Arrays.asList(null, null, "C"), 8.0,
Arrays.asList( 2.0, s, null), 48.0,
Arrays.asList(null, null, "A"), 65.0));
}
@Test
public void testDoubleValueTypeSum() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 20.0, null, 19.0));
}
@Test
public void testDoubleValueTypeSumOneKey() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1, 1, 1, 1}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 39.0));
}
@Test
public void testDoubleValueTypeCount() throws HiveException {
testKeyTypeAggregate(
"count",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 2L, null, 1L));
}
/**
 * COUNT over a double column with a single tinyint key value. The null value
 * at position 1 is skipped by count, so key (byte)1 counts 3 of the 4 rows.
 * BUG FIX: this method was missing the {@code @Test} annotation (every sibling
 * test has it), so JUnit silently never ran it.
 */
@Test
public void testDoubleValueTypeCountOneKey() throws HiveException {
testKeyTypeAggregate(
"count",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1, 1, 1, 1}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 3L));
}
@Test
public void testDoubleValueTypeAvg() throws HiveException {
testKeyTypeAggregate(
"avg",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 10.0, null, 19.0));
}
@Test
public void testDoubleValueTypeAvgOneKey() throws HiveException {
testKeyTypeAggregate(
"avg",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1, 1, 1, 1}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 13.0));
}
@Test
public void testDoubleValueTypeMin() throws HiveException {
testKeyTypeAggregate(
"min",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 7.0, null, 19.0));
}
@Test
public void testDoubleValueTypeMinOneKey() throws HiveException {
testKeyTypeAggregate(
"min",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1, 1, 1, 1}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 7.0));
}
@Test
public void testDoubleValueTypeMax() throws HiveException {
testKeyTypeAggregate(
"max",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 13.0, null, 19.0));
}
@Test
public void testDoubleValueTypeMaxOneKey() throws HiveException {
testKeyTypeAggregate(
"max",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1, 1, 1, 1}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 19.0));
}
@Test
public void testDoubleValueTypeVariance() throws HiveException {
testKeyTypeAggregate(
"variance",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 9.0, null, 0.0));
}
@Test
public void testDoubleValueTypeVarianceOneKey() throws HiveException {
testKeyTypeAggregate(
"variance",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "double"},
Arrays.asList(new Object[]{ 1, 1, 1, 1}),
Arrays.asList(new Object[]{13.0,null,7.0, 19.0})),
buildHashMap((byte)1, 24.0));
}
@Test
public void testTinyintKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"tinyint", "bigint"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap((byte)1, 20L, null, 19L));
}
@Test
public void testSmallintKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"smallint", "bigint"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap((short)1, 20L, null, 19L));
}
@Test
public void testIntKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"int", "bigint"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap(1, 20L, null, 19L));
}
@Test
public void testBigintKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"bigint", "bigint"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap(1L, 20L, null, 19L));
}
@Test
public void testBooleanKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"boolean", "bigint"},
Arrays.asList(new Object[]{ true,null, true, null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap(true, 20L, null, 19L));
}
@Test
public void testTimestampKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"timestamp", "bigint"},
Arrays.asList(new Object[]{new Timestamp(1),null, new Timestamp(1), null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap(new Timestamp(1), 20L, null, 19L));
}
@Test
public void testFloatKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"float", "bigint"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap((float)1.0, 20L, null, 19L));
}
@Test
public void testDoubleKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"double", "bigint"},
Arrays.asList(new Object[]{ 1,null, 1, null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap(1.0, 20L, null, 19L));
}
@Test
public void testCountStar() throws HiveException {
testAggregateCountStar(
2,
Arrays.asList(new Long[]{13L,null,7L,19L}),
4L);
}
@Test
public void testCountReduce() throws HiveException {
// COUNT in reduce (merge) mode: the input longs are partial counts produced by
// map-side aggregation, so the result is their SUM, not the number of rows
// (e.g. {13, 0, 7, 19} merges to 39, and {0} merges to 0).
testAggregateCountReduce(
2,
Arrays.asList(new Long[]{}),
0L);
testAggregateCountReduce(
2,
Arrays.asList(new Long[]{0L}),
0L);
testAggregateCountReduce(
2,
Arrays.asList(new Long[]{0L,0L}),
0L);
testAggregateCountReduce(
2,
Arrays.asList(new Long[]{0L,1L,0L}),
1L);
testAggregateCountReduce(
2,
Arrays.asList(new Long[]{13L,0L,7L,19L}),
39L);
}
@Test
public void testCountDecimal() throws HiveException {
testAggregateDecimal(
"Decimal",
"count",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(1),
HiveDecimal.create(2),
HiveDecimal.create(3)}),
3L);
}
@Test
public void testMaxDecimal() throws HiveException {
testAggregateDecimal(
"Decimal",
"max",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(1),
HiveDecimal.create(2),
HiveDecimal.create(3)}),
HiveDecimal.create(3));
testAggregateDecimal(
"Decimal",
"max",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(3),
HiveDecimal.create(2),
HiveDecimal.create(1)}),
HiveDecimal.create(3));
testAggregateDecimal(
"Decimal",
"max",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(2),
HiveDecimal.create(3),
HiveDecimal.create(1)}),
HiveDecimal.create(3));
}
@Test
public void testMinDecimal() throws HiveException {
testAggregateDecimal(
"Decimal",
"min",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(1),
HiveDecimal.create(2),
HiveDecimal.create(3)}),
HiveDecimal.create(1));
testAggregateDecimal(
"Decimal",
"min",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(3),
HiveDecimal.create(2),
HiveDecimal.create(1)}),
HiveDecimal.create(1));
testAggregateDecimal(
"Decimal",
"min",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(2),
HiveDecimal.create(1),
HiveDecimal.create(3)}),
HiveDecimal.create(1));
}
@Test
public void testSumDecimal() throws HiveException {
testAggregateDecimal(
"Decimal",
"sum",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(1),
HiveDecimal.create(2),
HiveDecimal.create(3)}),
HiveDecimal.create(1+2+3));
}
@Test
public void testSumDecimalHive6508() throws HiveException {
short scale = 4;
testAggregateDecimal(
"Decimal(10,4)",
"sum",
4,
Arrays.asList(new Object[]{
HiveDecimal.create("1234.2401"),
HiveDecimal.create("1868.52"),
HiveDecimal.ZERO,
HiveDecimal.create("456.84"),
HiveDecimal.create("121.89")}),
HiveDecimal.create("3681.4901"));
}
@Test
public void testAvgDecimal() throws HiveException {
testAggregateDecimal(
"Decimal",
"avg",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(1),
HiveDecimal.create(2),
HiveDecimal.create(3)}),
HiveDecimal.create((1+2+3)/3));
}
@Test
public void testAvgDecimalNegative() throws HiveException {
testAggregateDecimal(
"Decimal",
"avg",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(-1),
HiveDecimal.create(-2),
HiveDecimal.create(-3)}),
HiveDecimal.create((-1-2-3)/3));
}
@Test
public void testVarianceDecimal () throws HiveException {
testAggregateDecimal(
"Decimal",
"variance",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(13),
HiveDecimal.create(5),
HiveDecimal.create(7),
HiveDecimal.create(19)}),
(double) 30);
}
@Test
public void testVarSampDecimal () throws HiveException {
testAggregateDecimal(
"Decimal",
"var_samp",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(13),
HiveDecimal.create(5),
HiveDecimal.create(7),
HiveDecimal.create(19)}),
(double) 40);
}
@Test
public void testStdPopDecimal () throws HiveException {
testAggregateDecimal(
"Decimal",
"stddev_pop",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(13),
HiveDecimal.create(5),
HiveDecimal.create(7),
HiveDecimal.create(19)}),
Math.sqrt(30));
}
@Test
public void testStdSampDecimal () throws HiveException {
testAggregateDecimal(
"Decimal",
"stddev_samp",
2,
Arrays.asList(new Object[]{
HiveDecimal.create(13),
HiveDecimal.create(5),
HiveDecimal.create(7),
HiveDecimal.create(19)}),
Math.sqrt(40));
}
@Test
public void testDecimalKeyTypeAggregate() throws HiveException {
testKeyTypeAggregate(
"sum",
new FakeVectorRowBatchFromObjectIterables(
2,
new String[] {"decimal(38,0)", "bigint"},
Arrays.asList(new Object[]{
HiveDecimal.create(1),null,
HiveDecimal.create(1), null}),
Arrays.asList(new Object[]{13L,null,7L, 19L})),
buildHashMap(HiveDecimal.create(1), 20L, null, 19L));
}
@Test
public void testCountString() throws HiveException {
testAggregateString(
"count",
2,
Arrays.asList(new Object[]{"A","B","C"}),
3L);
}
@Test
public void testMaxString() throws HiveException {
testAggregateString(
"max",
2,
Arrays.asList(new Object[]{"A","B","C"}),
"C");
testAggregateString(
"max",
2,
Arrays.asList(new Object[]{"C", "B", "A"}),
"C");
}
@Test
public void testMinString() throws HiveException {
testAggregateString(
"min",
2,
Arrays.asList(new Object[]{"A","B","C"}),
"A");
testAggregateString(
"min",
2,
Arrays.asList(new Object[]{"C", "B", "A"}),
"A");
}
@Test
public void testMaxNullString() throws HiveException {
testAggregateString(
"max",
2,
Arrays.asList(new Object[]{"A","B",null}),
"B");
testAggregateString(
"max",
2,
Arrays.asList(new Object[]{null, null, null}),
null);
}
@Test
public void testCountStringWithNull() throws HiveException {
testAggregateString(
"count",
2,
Arrays.asList(new Object[]{"A",null,"C", "D", null}),
3L);
}
@Test
public void testCountStringAllNull() throws HiveException {
testAggregateString(
"count",
4,
Arrays.asList(new Object[]{null, null, null, null, null}),
0L);
}
@Test
public void testMinLongNullStringKeys() throws HiveException {
testAggregateStringKeyAggregate(
"min",
2,
Arrays.asList(new Object[]{"A",null,"A",null}),
Arrays.asList(new Object[]{13L, 5L, 7L,19L}),
buildHashMap("A", 7L, null, 5L));
}
@Test
public void testMinLongStringKeys() throws HiveException {
testAggregateStringKeyAggregate(
"min",
2,
Arrays.asList(new Object[]{"A","B","A","B"}),
Arrays.asList(new Object[]{13L, 5L, 7L,19L}),
buildHashMap("A", 7L, "B", 5L));
}
@Test
public void testMinLongKeyGroupByCompactBatch() throws HiveException {
testAggregateLongKeyAggregate(
"min",
2,
Arrays.asList(new Long[]{01L,1L,2L,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(1L, 5L, 2L, 7L));
}
@Test
public void testMinLongKeyGroupBySingleBatch() throws HiveException {
testAggregateLongKeyAggregate(
"min",
4,
Arrays.asList(new Long[]{01L,1L,2L,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(1L, 5L, 2L, 7L));
}
@Test
public void testMinLongKeyGroupByCrossBatch() throws HiveException {
testAggregateLongKeyAggregate(
"min",
2,
Arrays.asList(new Long[]{01L,2L,1L,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(1L, 7L, 2L, 5L));
}
@Test
public void testMinLongNullKeyGroupByCrossBatch() throws HiveException {
testAggregateLongKeyAggregate(
"min",
2,
Arrays.asList(new Long[]{null,2L,null,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(null, 7L, 2L, 5L));
}
@Test
public void testMinLongNullKeyGroupBySingleBatch() throws HiveException {
testAggregateLongKeyAggregate(
"min",
4,
Arrays.asList(new Long[]{null,2L,null,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(null, 7L, 2L, 5L));
}
@Test
public void testMaxLongNullKeyGroupBySingleBatch() throws HiveException {
testAggregateLongKeyAggregate(
"max",
4,
Arrays.asList(new Long[]{null,2L,null,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(null, 13L, 2L, 19L));
}
@Test
public void testCountLongNullKeyGroupBySingleBatch() throws HiveException {
testAggregateLongKeyAggregate(
"count",
4,
Arrays.asList(new Long[]{null,2L,null,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(null, 2L, 2L, 2L));
}
@Test
public void testSumLongNullKeyGroupBySingleBatch() throws HiveException {
testAggregateLongKeyAggregate(
"sum",
4,
Arrays.asList(new Long[]{null,2L,null,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(null, 20L, 2L, 24L));
}
@Test
public void testAvgLongNullKeyGroupBySingleBatch() throws HiveException {
testAggregateLongKeyAggregate(
"avg",
4,
Arrays.asList(new Long[]{null,2L,null,02L}),
Arrays.asList(new Long[]{13L,5L,7L,19L}),
buildHashMap(null, 10.0, 2L, 12.0));
}
@Test
public void testVarLongNullKeyGroupBySingleBatch() throws HiveException {
testAggregateLongKeyAggregate(
"variance",
4,
Arrays.asList(new Long[]{null,2L,01L,02L,01L,01L}),
Arrays.asList(new Long[]{13L, 5L,18L,19L,12L,15L}),
buildHashMap(null, 0.0, 2L, 49.0, 01L, 6.0));
}
@Test
public void testMinNullLongNullKeyGroupBy() throws HiveException {
testAggregateLongKeyAggregate(
"min",
4,
Arrays.asList(new Long[]{null,2L,null,02L}),
Arrays.asList(new Long[]{null, null, null, null}),
buildHashMap(null, null, 2L, null));
}
@Test
public void testMinLongGroupBy() throws HiveException {
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{13L,5L,7L,19L}),
5L);
}
@Test
public void testMinLongSimple() throws HiveException {
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{13L,5L,7L,19L}),
5L);
}
@Test
public void testMinLongEmpty() throws HiveException {
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{}),
null);
}
@Test
public void testMinLongNulls() throws HiveException {
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{null}),
null);
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{null, null, null}),
null);
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{null,5L,7L,19L}),
5L);
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{13L,null,7L,19L}),
7L);
}
@Test
public void testMinLongRepeat () throws HiveException {
testAggregateLongRepeats (
"min",
42L,
4096,
1024,
42L);
}
@Test
public void testMinLongRepeatNulls () throws HiveException {
testAggregateLongRepeats (
"min",
null,
4096,
1024,
null);
}
@Test
public void testMinLongNegative () throws HiveException {
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{13L,5L,7L,-19L}),
-19L);
}
@Test
public void testMinLongMinInt () throws HiveException {
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{13L,5L,(long)Integer.MIN_VALUE,-19L}),
(long)Integer.MIN_VALUE);
}
@Test
public void testMinLongMinLong () throws HiveException {
testAggregateLongAggregate(
"min",
2,
Arrays.asList(new Long[]{13L,5L, Long.MIN_VALUE, (long)Integer.MIN_VALUE}),
Long.MIN_VALUE);
}
@Test
public void testMaxLongSimple () throws HiveException {
testAggregateLongAggregate(
"max",
2,
Arrays.asList(new Long[]{13L,5L,7L,19L}),
19L);
}
@Test
public void testMaxLongEmpty () throws HiveException {
testAggregateLongAggregate(
"max",
2,
Arrays.asList(new Long[]{}),
null);
}
@Test
public void testMaxLongNegative () throws HiveException {
testAggregateLongAggregate(
"max",
2,
Arrays.asList(new Long[]{-13L,-5L,-7L,-19L}),
-5L);
}
@Test
public void testMaxLongMaxInt () throws HiveException {
testAggregateLongAggregate(
"max",
2,
Arrays.asList(new Long[]{13L,5L,7L,(long)Integer.MAX_VALUE}),
(long)Integer.MAX_VALUE);
}
@Test
public void testMaxLongMaxLong () throws HiveException {
testAggregateLongAggregate(
"max",
2,
Arrays.asList(new Long[]{13L,Long.MAX_VALUE - 1L,Long.MAX_VALUE,(long)Integer.MAX_VALUE}),
Long.MAX_VALUE);
}
@Test
public void testMaxLongRepeat () throws HiveException {
testAggregateLongRepeats (
"max",
42L,
4096,
1024,
42L);
}
@Test
public void testMaxLongNulls () throws HiveException {
testAggregateLongRepeats (
"max",
null,
4096,
1024,
null);
}
@SuppressWarnings("unchecked")
@Test
public void testMinLongConcatRepeat () throws HiveException {
testAggregateLongIterable ("min",
new FakeVectorRowBatchFromConcat(
new FakeVectorRowBatchFromRepeats(
new Long[] {19L}, 10, 2),
new FakeVectorRowBatchFromRepeats(
new Long[] {7L}, 15, 2),
new FakeVectorRowBatchFromRepeats(
new Long[] {19L}, 10, 2)),
7L);
}
@SuppressWarnings("unchecked")
@Test
public void testMinLongRepeatConcatValues () throws HiveException {
testAggregateLongIterable ("min",
new FakeVectorRowBatchFromConcat(
new FakeVectorRowBatchFromRepeats(
new Long[] {19L}, 10, 2),
new FakeVectorRowBatchFromLongIterables(
3,
Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))),
7L);
}
@Test
public void testCountLongSimple () throws HiveException {
testAggregateLongAggregate(
"count",
2,
Arrays.asList(new Long[]{13L,5L,7L,19L}),
4L);
}
@Test
public void testCountLongEmpty () throws HiveException {
testAggregateLongAggregate(
"count",
2,
Arrays.asList(new Long[]{}),
0L);
}
@Test
public void testCountLongNulls () throws HiveException {
testAggregateLongAggregate(
"count",
2,
Arrays.asList(new Long[]{null}),
0L);
testAggregateLongAggregate(
"count",
2,
Arrays.asList(new Long[]{null, null, null}),
0L);
testAggregateLongAggregate(
"count",
2,
Arrays.asList(new Long[]{null,5L,7L,19L}),
3L);
testAggregateLongAggregate(
"count",
2,
Arrays.asList(new Long[]{13L,null,7L,19L}),
3L);
}
@Test
public void testCountLongRepeat () throws HiveException {
testAggregateLongRepeats (
"count",
42L,
4096,
1024,
4096L);
}
@Test
public void testCountLongRepeatNulls () throws HiveException {
testAggregateLongRepeats (
"count",
null,
4096,
1024,
0L);
}
@SuppressWarnings("unchecked")
@Test
public void testCountLongRepeatConcatValues () throws HiveException {
testAggregateLongIterable ("count",
new FakeVectorRowBatchFromConcat(
new FakeVectorRowBatchFromRepeats(
new Long[] {19L}, 10, 2),
new FakeVectorRowBatchFromLongIterables(
3,
Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))),
14L);
}
// sum(double) over non-null values, batch size 2.
@Test
public void testSumDoubleSimple() throws HiveException {
  testAggregateDouble(
      "sum",
      2,
      Arrays.asList(new Object[]{13.0,5.0,7.0,19.0}),
      13.0 + 5.0 + 7.0 + 19.0);
}

// sum(double) grouped by a string key; NULL is itself a valid group key.
@Test
public void testSumDoubleGroupByString() throws HiveException {
  testAggregateDoubleStringKeyAggregate(
      "sum",
      4,
      Arrays.asList(new Object[]{"A", null, "A", null}),
      Arrays.asList(new Object[]{13.0,5.0,7.0,19.0}),
      buildHashMap("A", 20.0, null, 24.0));
}

// sum(long) over non-null values.
@Test
public void testSumLongSimple () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{13L,5L,7L,19L}),
      13L + 5L + 7L + 19L);
}

// sum over an empty input is NULL, not 0.
@Test
public void testSumLongEmpty () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{}),
      null);
}

// sum skips NULLs; an all-NULL input yields NULL.
@Test
public void testSumLongNulls () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{null}),
      null);
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{null, null, null}),
      null);
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{null,5L,7L,19L}),
      5L + 7L + 19L);
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{13L,null,7L,19L}),
      13L + 7L + 19L);
}

// Repeating non-null vector: sum = repeat count * value.
@Test
public void testSumLongRepeat () throws HiveException {
  testAggregateLongRepeats (
      "sum",
      42L,
      4096,
      1024,
      4096L * 42L);
}

// Repeating NULL vector sums to NULL.
@Test
public void testSumLongRepeatNulls () throws HiveException {
  testAggregateLongRepeats (
      "sum",
      null,
      4096,
      1024,
      null);
}

// Repeats concatenated with varied values; both sources contribute.
@SuppressWarnings("unchecked")
@Test
public void testSumLongRepeatConcatValues () throws HiveException {
  testAggregateLongIterable ("sum",
      new FakeVectorRowBatchFromConcat(
          new FakeVectorRowBatchFromRepeats(
              new Long[] {19L}, 10, 2),
          new FakeVectorRowBatchFromLongIterables(
              3,
              Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))),
      19L*10L + 13L + 7L + 23L +29L);
}

// Values cancelling to exactly zero must still produce 0 (not NULL).
@Test
public void testSumLongZero () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{-(long)Integer.MAX_VALUE, (long)Integer.MAX_VALUE}),
      0L);
}

// Sums exceeding the int range must not be truncated to 32 bits.
@Test
public void testSumLong2MaxInt () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{(long)Integer.MAX_VALUE, (long)Integer.MAX_VALUE}),
      4294967294L);
}

@Test
public void testSumLong2MinInt () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{(long)Integer.MIN_VALUE, (long)Integer.MIN_VALUE}),
      -4294967296L);
}

// Long overflow wraps silently (two's complement), matching Hive semantics.
@Test
public void testSumLong2MaxLong () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{Long.MAX_VALUE, Long.MAX_VALUE}),
      -2L); // silent overflow
}

@Test
public void testSumLong2MinLong () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{Long.MIN_VALUE, Long.MIN_VALUE}),
      0L); // silent overflow
}

// MAX + MIN = -1 exactly, no wrap involved.
@Test
public void testSumLongMinMaxLong () throws HiveException {
  testAggregateLongAggregate(
      "sum",
      2,
      Arrays.asList(new Long[]{Long.MAX_VALUE, Long.MIN_VALUE}),
      -1L);
}
// avg of non-null longs is returned as a double mean.
@Test
public void testAvgLongSimple () throws HiveException {
  testAggregateLongAggregate(
      "avg",
      2,
      Arrays.asList(new Long[]{13L,5L,7L,19L}),
      (double) (13L + 5L + 7L + 19L) / (double) 4L);
}

// avg over an empty input is NULL.
@Test
public void testAvgLongEmpty () throws HiveException {
  testAggregateLongAggregate(
      "avg",
      2,
      Arrays.asList(new Long[]{}),
      null);
}
// avg skips NULLs: all-NULL (like empty) input yields NULL; otherwise NULLs
// are excluded from both the running sum and the divisor row count.
@Test
public void testAvgLongNulls () throws HiveException {
  testAggregateLongAggregate(
      "avg",
      2,
      Arrays.asList(new Long[]{null}),
      null);
  testAggregateLongAggregate(
      "avg",
      2,
      Arrays.asList(new Long[]{null, null, null}),
      null);
  testAggregateLongAggregate(
      "avg",
      2,
      Arrays.asList(new Long[]{null,5L,7L,19L}),
      (double) (5L + 7L + 19L) / (double) 3L);
  testAggregateLongAggregate(
      "avg",
      2,
      Arrays.asList(new Long[]{13L,null,7L,19L}),
      // Fixed a stray unary '+' in the original expression (13L + + 7L);
      // the computed value is unchanged.
      (double) (13L + 7L + 19L) / (double) 3L);
}
// avg of a repeated constant is that constant.
@Test
public void testAvgLongRepeat () throws HiveException {
  testAggregateLongRepeats (
      "avg",
      42L,
      4096,
      1024,
      (double)42);
}

// All-NULL repeating input -> NULL average.
@Test
public void testAvgLongRepeatNulls () throws HiveException {
  testAggregateLongRepeats (
      "avg",
      null,
      4096,
      1024,
      null);
}

// avg across concatenated repeating + varied batches: total / 14 rows.
@SuppressWarnings("unchecked")
@Test
public void testAvgLongRepeatConcatValues () throws HiveException {
  testAggregateLongIterable ("avg",
      new FakeVectorRowBatchFromConcat(
          new FakeVectorRowBatchFromRepeats(
              new Long[] {19L}, 10, 2),
          new FakeVectorRowBatchFromLongIterables(
              3,
              Arrays.asList(new Long[]{13L, 7L, 23L, 29L}))),
      (double) (19L*10L + 13L + 7L + 23L +29L) / (double) 14 );
}

// Population variance of {13,5,7,19}: mean 11, squared deviations sum to
// 120, divided by n=4 gives 30.
@Test
public void testVarianceLongSimple () throws HiveException {
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{13L,5L,7L,19L}),
      (double) 30L);
}

// Variance over an empty input is NULL.
@Test
public void testVarianceLongEmpty () throws HiveException {
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{}),
      null);
}

// A single row has zero population variance.
@Test
public void testVarianceLongSingle () throws HiveException {
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{97L}),
      0.0);
}

// NULLs are excluded from variance; a lone non-null row gives 0.
@Test
public void testVarianceLongNulls () throws HiveException {
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{null}),
      null);
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{null, null, null}),
      null);
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{null,13L, 5L,7L,19L}),
      30.0);
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{13L,null,5L, 7L,19L}),
      30.0);
  testAggregateLongAggregate(
      "variance",
      2,
      Arrays.asList(new Long[]{null,null,null,19L}),
      (double) 0);
}

// All-NULL repeating input -> NULL var_pop.
@Test
public void testVarPopLongRepeatNulls () throws HiveException {
  testAggregateLongRepeats (
      "var_pop",
      null,
      4096,
      1024,
      null);
}

// Constant input -> zero population variance.
@Test
public void testVarPopLongRepeat () throws HiveException {
  testAggregateLongRepeats (
      "var_pop",
      42L,
      4096,
      1024,
      (double)0);
}

// Sample variance divides by (n-1): 120 / 3 = 40.
@Test
public void testVarSampLongSimple () throws HiveException {
  testAggregateLongAggregate(
      "var_samp",
      2,
      Arrays.asList(new Long[]{13L,5L,7L,19L}),
      (double) 40L);
}

@Test
public void testVarSampLongEmpty () throws HiveException {
  testAggregateLongAggregate(
      "var_samp",
      2,
      Arrays.asList(new Long[]{}),
      null);
}

// Constant input -> zero sample variance.
@Test
public void testVarSampLongRepeat () throws HiveException {
  testAggregateLongRepeats (
      "var_samp",
      42L,
      4096,
      1024,
      (double)0);
}

// std = sqrt(population variance).
@Test
public void testStdLongSimple () throws HiveException {
  testAggregateLongAggregate(
      "std",
      2,
      Arrays.asList(new Long[]{13L,5L,7L,19L}),
      Math.sqrt(30));
}

@Test
public void testStdLongEmpty () throws HiveException {
  testAggregateLongAggregate(
      "std",
      2,
      Arrays.asList(new Long[]{}),
      null);
}

// Constant input -> zero standard deviation.
@Test
public void testStdDevLongRepeat () throws HiveException {
  testAggregateLongRepeats (
      "stddev",
      42L,
      4096,
      1024,
      (double)0);
}

// All-NULL repeating input -> NULL stddev.
@Test
public void testStdDevLongRepeatNulls () throws HiveException {
  testAggregateLongRepeats (
      "stddev",
      null,
      4096,
      1024,
      null);
}

// stddev_samp = sqrt(variance / (n-1)) = sqrt(120/3).
@Test
public void testStdDevSampSimple () throws HiveException {
  testAggregateLongAggregate(
      "stddev_samp",
      2,
      Arrays.asList(new Long[]{13L,5L,7L,19L}),
      Math.sqrt(40));
}

// Repeat count (3) smaller than the batch size still gives zero deviation.
@Test
public void testStdDevSampLongRepeat () throws HiveException {
  testAggregateLongRepeats (
      "stddev_samp",
      42L,
      3,
      1024,
      (double)0);
}
/**
 * End-to-end test of hash-mode vector GROUP BY with multiple key columns.
 * The last column of {@code data} is the aggregation input; every preceding
 * column is a grouping key. {@code expected} maps each key tuple (a List of
 * plain Java values, so List.equals gives tuple equality) to the expected
 * aggregate value.
 *
 * Changes vs. the original: removed an unused local map of column names and
 * an unused rowIndex field, and replaced a raw-typed ArrayList with a
 * properly parameterized one. Operator wiring and assertions are unchanged.
 */
private void testMultiKey(
    String aggregateName,
    FakeVectorRowBatchFromObjectIterables data,
    HashMap<Object, Object> expected) throws HiveException {
  ArrayList<String> outputColumnNames = new ArrayList<String>();
  ArrayList<ExprNodeDesc> keysDesc = new ArrayList<ExprNodeDesc>();
  Set<Object> keys = new HashSet<Object>();
  // The types array tells us the number of columns in the data
  final String[] columnTypes = data.getTypes();
  // Columns 0..N-1 are keys. Column N is the aggregate value input
  int i = 0;
  for (; i < columnTypes.length - 1; ++i) {
    outputColumnNames.add(String.format("_col%d", i));
  }
  outputColumnNames.add("value");
  VectorizationContext ctx = new VectorizationContext("name", outputColumnNames);
  // Parameterized ArrayList (the original used a raw 'new ArrayList(1)').
  ArrayList<AggregationDesc> aggs = new ArrayList<AggregationDesc>(1);
  // 'i' now indexes the value column (the last type in columnTypes).
  aggs.add(
      buildAggregationDesc(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1,
          "value", TypeInfoFactory.getPrimitiveTypeInfo(columnTypes[i])));
  for (i = 0; i < columnTypes.length - 1; ++i) {
    String columnName = String.format("_col%d", i);
    keysDesc.add(
        buildColumnDesc(ctx, columnName,
            TypeInfoFactory.getPrimitiveTypeInfo(columnTypes[i])));
  }
  GroupByDesc desc = new GroupByDesc();
  desc.setVectorDesc(new VectorGroupByDesc());
  desc.setOutputColumnNames(outputColumnNames);
  desc.setAggregators(aggs);
  desc.setKeys(keysDesc);
  ((VectorGroupByDesc) desc.getVectorDesc()).setProcessingMode(ProcessingMode.HASH);
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() {
    private String aggregateName;
    private Map<Object,Object> expected;
    private Set<Object> keys;
    @Override
    public void inspectRow(Object row, int tag) throws HiveException {
      assertTrue(row instanceof Object[]);
      Object[] fields = (Object[]) row;
      assertEquals(columnTypes.length, fields.length);
      // Unwrap each writable key field into a plain Java value so the tuple
      // can be looked up in the expected-results map.
      ArrayList<Object> keyValue = new ArrayList<Object>(columnTypes.length-1);
      for(int i=0; i<columnTypes.length-1; ++i) {
        Object key = fields[i];
        if (null == key) {
          keyValue.add(null);
        } else if (key instanceof Text) {
          Text txKey = (Text)key;
          keyValue.add(txKey.toString());
        } else if (key instanceof ByteWritable) {
          ByteWritable bwKey = (ByteWritable)key;
          keyValue.add(bwKey.get());
        } else if (key instanceof ShortWritable) {
          ShortWritable swKey = (ShortWritable)key;
          keyValue.add(swKey.get());
        } else if (key instanceof IntWritable) {
          IntWritable iwKey = (IntWritable)key;
          keyValue.add(iwKey.get());
        } else if (key instanceof LongWritable) {
          LongWritable lwKey = (LongWritable)key;
          keyValue.add(lwKey.get());
        } else if (key instanceof TimestampWritable) {
          TimestampWritable twKey = (TimestampWritable)key;
          keyValue.add(twKey.getTimestamp());
        } else if (key instanceof DoubleWritable) {
          DoubleWritable dwKey = (DoubleWritable)key;
          keyValue.add(dwKey.get());
        } else if (key instanceof FloatWritable) {
          FloatWritable fwKey = (FloatWritable)key;
          keyValue.add(fwKey.get());
        } else if (key instanceof BooleanWritable) {
          BooleanWritable bwKey = (BooleanWritable)key;
          keyValue.add(bwKey.get());
        } else {
          Assert.fail(String.format("Not implemented key output type %s: %s",
              key.getClass().getName(), key));
        }
      }
      String keyAsString = Arrays.deepToString(keyValue.toArray());
      assertTrue(expected.containsKey(keyValue));
      Object expectedValue = expected.get(keyValue);
      Object value = fields[columnTypes.length-1];
      Validator validator = getValidator(aggregateName);
      validator.validate(keyAsString, expectedValue, new Object[] {value});
      // Track distinct key tuples so we can assert full coverage below.
      keys.add(keyValue);
    }
    // Poor-man's constructor: anonymous classes cannot declare one, so the
    // captured state is injected via init() at creation time.
    private FakeCaptureOutputOperator.OutputInspector init(
        String aggregateName, Map<Object,Object> expected, Set<Object> keys) {
      this.aggregateName = aggregateName;
      this.expected = expected;
      this.keys = keys;
      return this;
    }
  }.init(aggregateName, expected, keys));
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  // One output row per distinct key tuple, each seen exactly once.
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(expected.size(), outBatchList.size());
  assertEquals(expected.size(), keys.size());
}
/**
 * Exercises single-key hash GROUP BY for an arbitrary primitive key type.
 * Column 0 of {@code data} is the key, column 1 the aggregated value;
 * {@code expected} maps each key (as a plain Java value) to the expected
 * aggregate result.
 */
private void testKeyTypeAggregate(
    String aggregateName,
    FakeVectorRowBatchFromObjectIterables data,
    Map<Object, Object> expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("Key");
  mapColumnNames.add("Value");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  Set<Object> keys = new HashSet<Object>();
  AggregationDesc agg = buildAggregationDesc(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1,
      "Value", TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[1]));
  ArrayList<AggregationDesc> aggs = new ArrayList<AggregationDesc>();
  aggs.add(agg);
  ArrayList<String> outputColumnNames = new ArrayList<String>();
  outputColumnNames.add("_col0");
  outputColumnNames.add("_col1");
  GroupByDesc desc = new GroupByDesc();
  desc.setVectorDesc(new VectorGroupByDesc());
  desc.setOutputColumnNames(outputColumnNames);
  desc.setAggregators(aggs);
  ((VectorGroupByDesc) desc.getVectorDesc()).setProcessingMode(ProcessingMode.HASH);
  // The key expression's type comes from column 0 of the fake batch source.
  ExprNodeDesc keyExp = buildColumnDesc(ctx, "Key",
      TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[0]));
  ArrayList<ExprNodeDesc> keysDesc = new ArrayList<ExprNodeDesc>();
  keysDesc.add(keyExp);
  desc.setKeys(keysDesc);
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() {
    private int rowIndex;
    private String aggregateName;
    private Map<Object,Object> expected;
    private Set<Object> keys;
    @Override
    public void inspectRow(Object row, int tag) throws HiveException {
      assertTrue(row instanceof Object[]);
      Object[] fields = (Object[]) row;
      assertEquals(2, fields.length);
      Object key = fields[0];
      // Unwrap the writable key into a plain Java value so it can be
      // compared against the keys of the expected-results map.
      Object keyValue = null;
      if (null == key) {
        keyValue = null;
      } else if (key instanceof ByteWritable) {
        ByteWritable bwKey = (ByteWritable)key;
        keyValue = bwKey.get();
      } else if (key instanceof ShortWritable) {
        ShortWritable swKey = (ShortWritable)key;
        keyValue = swKey.get();
      } else if (key instanceof IntWritable) {
        IntWritable iwKey = (IntWritable)key;
        keyValue = iwKey.get();
      } else if (key instanceof LongWritable) {
        LongWritable lwKey = (LongWritable)key;
        keyValue = lwKey.get();
      } else if (key instanceof TimestampWritable) {
        TimestampWritable twKey = (TimestampWritable)key;
        keyValue = twKey.getTimestamp();
      } else if (key instanceof DoubleWritable) {
        DoubleWritable dwKey = (DoubleWritable)key;
        keyValue = dwKey.get();
      } else if (key instanceof FloatWritable) {
        FloatWritable fwKey = (FloatWritable)key;
        keyValue = fwKey.get();
      } else if (key instanceof BooleanWritable) {
        BooleanWritable bwKey = (BooleanWritable)key;
        keyValue = bwKey.get();
      } else if (key instanceof HiveDecimalWritable) {
        HiveDecimalWritable hdwKey = (HiveDecimalWritable)key;
        keyValue = hdwKey.getHiveDecimal();
      } else {
        Assert.fail(String.format("Not implemented key output type %s: %s",
            key.getClass().getName(), key));
      }
      String keyValueAsString = String.format("%s", keyValue);
      assertTrue(expected.containsKey(keyValue));
      Object expectedValue = expected.get(keyValue);
      Object value = fields[1];
      Validator validator = getValidator(aggregateName);
      validator.validate(keyValueAsString, expectedValue, new Object[] {value});
      // Track distinct keys so full coverage can be asserted after close().
      keys.add(keyValue);
    }
    // Poor-man's constructor for the anonymous class: inject captured state.
    private FakeCaptureOutputOperator.OutputInspector init(
        String aggregateName, Map<Object,Object> expected, Set<Object> keys) {
      this.aggregateName = aggregateName;
      this.expected = expected;
      this.keys = keys;
      return this;
    }
  }.init(aggregateName, expected, keys));
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  // One output row per distinct key, each validated exactly once.
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(expected.size(), outBatchList.size());
  assertEquals(expected.size(), keys.size());
}
/**
 * Drives the aggregate under test with {@code repeat} copies of a single
 * value, delivered in batches of {@code batchSize}, and checks the result.
 *
 * @param value the repeated cell value; null exercises the NULL path
 */
public void testAggregateLongRepeats (
    String aggregateName,
    Long value,
    int repeat,
    int batchSize,
    Object expected) throws HiveException {
  testAggregateLongIterable(aggregateName,
      new FakeVectorRowBatchFromRepeats(new Long[] {value}, repeat, batchSize),
      expected);
}
/**
 * Builds a HashMap from an interleaved (key, value, key, value, ...)
 * varargs list.
 *
 * @param pairs alternating keys and values; the length must be even
 * @return a map of each key to the value that follows it
 * @throws IllegalArgumentException if an odd number of arguments is passed
 *         (the original would have failed with ArrayIndexOutOfBoundsException)
 */
public HashMap<Object, Object> buildHashMap(Object... pairs) {
  if (pairs.length % 2 != 0) {
    throw new IllegalArgumentException(
        "buildHashMap expects an even number of arguments (key/value pairs), got "
        + pairs.length);
  }
  HashMap<Object, Object> map = new HashMap<Object, Object>();
  for (int i = 0; i < pairs.length; i += 2) {
    map.put(pairs[i], pairs[i + 1]);
  }
  return map;
}
/** String-keyed aggregate over long values; delegates to the iterable driver. */
@SuppressWarnings("unchecked")
public void testAggregateStringKeyAggregate (
    String aggregateName,
    int batchSize,
    Iterable<Object> list,
    Iterable<Object> values,
    HashMap<Object, Object> expected) throws HiveException {
  testAggregateStringKeyIterable(aggregateName,
      new FakeVectorRowBatchFromObjectIterables(
          batchSize, new String[] {"string", "long"}, list, values),
      TypeInfoFactory.longTypeInfo, expected);
}

/** String-keyed aggregate over double values. */
@SuppressWarnings("unchecked")
public void testAggregateDoubleStringKeyAggregate (
    String aggregateName,
    int batchSize,
    Iterable<Object> list,
    Iterable<Object> values,
    HashMap<Object, Object> expected) throws HiveException {
  testAggregateStringKeyIterable(aggregateName,
      new FakeVectorRowBatchFromObjectIterables(
          batchSize, new String[] {"string", "double"}, list, values),
      TypeInfoFactory.doubleTypeInfo, expected);
}

/** Long-keyed aggregate over long values. */
@SuppressWarnings("unchecked")
public void testAggregateLongKeyAggregate (
    String aggregateName,
    int batchSize,
    List<Long> list,
    Iterable<Long> values,
    HashMap<Object, Object> expected) throws HiveException {
  testAggregateLongKeyIterable(aggregateName,
      new FakeVectorRowBatchFromLongIterables(batchSize, list, values),
      expected);
}

/** Global (no-key) aggregate over a decimal column of the given type name. */
@SuppressWarnings("unchecked")
public void testAggregateDecimal (
    String typeName,
    String aggregateName,
    int batchSize,
    Iterable<Object> values,
    Object expected) throws HiveException {
  testAggregateDecimalIterable(aggregateName,
      new FakeVectorRowBatchFromObjectIterables(
          batchSize, new String[] {typeName}, values),
      expected);
}

/** Global aggregate over a string column. */
@SuppressWarnings("unchecked")
public void testAggregateString (
    String aggregateName,
    int batchSize,
    Iterable<Object> values,
    Object expected) throws HiveException {
  testAggregateStringIterable(aggregateName,
      new FakeVectorRowBatchFromObjectIterables(
          batchSize, new String[] {"string"}, values),
      expected);
}

/** Global aggregate over a double column. */
@SuppressWarnings("unchecked")
public void testAggregateDouble (
    String aggregateName,
    int batchSize,
    Iterable<Object> values,
    Object expected) throws HiveException {
  testAggregateDoubleIterable(aggregateName,
      new FakeVectorRowBatchFromObjectIterables(
          batchSize, new String[] {"double"}, values),
      expected);
}

/** Global aggregate over a long column. */
@SuppressWarnings("unchecked")
public void testAggregateLongAggregate (
    String aggregateName,
    int batchSize,
    Iterable<Long> values,
    Object expected) throws HiveException {
  testAggregateLongIterable(aggregateName,
      new FakeVectorRowBatchFromLongIterables(batchSize, values),
      expected);
}

/** count(*) over a long column (row count, NULLs included). */
@SuppressWarnings("unchecked")
public void testAggregateCountStar (
    int batchSize,
    Iterable<Long> values,
    Object expected) throws HiveException {
  testAggregateCountStarIterable(
      new FakeVectorRowBatchFromLongIterables(batchSize, values),
      expected);
}

/** Reduce-side (merge) count over a long column. */
@SuppressWarnings("unchecked")
public void testAggregateCountReduce (
    int batchSize,
    Iterable<Long> values,
    Object expected) throws HiveException {
  testAggregateCountReduceIterable(
      new FakeVectorRowBatchFromLongIterables(batchSize, values),
      expected);
}
/**
 * Strategy for checking one aggregate output row against its expected value.
 * {@code key} is used only in assertion messages; {@code result} is the
 * operator's output row (an Object[] with a single aggregate field).
 */
public static interface Validator {
  void validate (String key, Object expected, Object result);
};
/**
 * Validates plain single-value aggregates (count/min/max/sum): the single
 * output field is unwrapped per its runtime type and compared to
 * {@code expected}. An unrecognized field type fails the test.
 */
public static class ValueValidator implements Validator {
  @Override
  public void validate(String key, Object expected, Object result) {
    assertEquals(true, result instanceof Object[]);
    Object[] arr = (Object[]) result;
    assertEquals(1, arr.length);
    if (expected == null) {
      // NULL aggregate (e.g. sum over empty input) must surface as null.
      assertEquals(key, null, arr[0]);
    } else if (arr[0] instanceof LongWritable) {
      LongWritable lw = (LongWritable) arr[0];
      assertEquals(key, expected, lw.get());
    } else if (arr[0] instanceof Text) {
      Text tx = (Text) arr[0];
      String sbw = tx.toString();
      assertEquals(key, expected, sbw);
    } else if (arr[0] instanceof DoubleWritable) {
      DoubleWritable dw = (DoubleWritable) arr[0];
      assertEquals (key, expected, dw.get());
    } else if (arr[0] instanceof Double) {
      assertEquals (key, expected, arr[0]);
    } else if (arr[0] instanceof Long) {
      assertEquals (key, expected, arr[0]);
    } else if (arr[0] instanceof HiveDecimalWritable) {
      HiveDecimalWritable hdw = (HiveDecimalWritable) arr[0];
      HiveDecimal hd = hdw.getHiveDecimal();
      HiveDecimal expectedDec = (HiveDecimal)expected;
      assertEquals (key, expectedDec, hd);
    } else if (arr[0] instanceof HiveDecimal) {
      HiveDecimal hd = (HiveDecimal) arr[0];
      HiveDecimal expectedDec = (HiveDecimal)expected;
      assertEquals (key, expectedDec, hd);
    } else {
      Assert.fail("Unsupported result type: " + arr[0].getClass().getName());
    }
  }
}
/**
 * Validates avg in PARTIAL1 mode: the output field is a (count, sum) struct,
 * so the validator recomputes sum/count and compares it to the expected mean.
 */
public static class AvgValidator implements Validator {
  @Override
  public void validate(String key, Object expected, Object result) {
    Object[] arr = (Object[]) result;
    assertEquals (1, arr.length);
    if (expected == null) {
      assertEquals(key, null, arr[0]);
    } else {
      assertEquals (true, arr[0] instanceof Object[]);
      Object[] vals = (Object[]) arr[0];
      assertEquals (2, vals.length);
      assertEquals (true, vals[0] instanceof LongWritable);
      LongWritable lw = (LongWritable) vals[0];
      // A non-null average implies at least one counted row (and guards the
      // division below against a zero divisor).
      assertFalse (lw.get() == 0L);
      if (vals[1] instanceof DoubleWritable) {
        DoubleWritable dw = (DoubleWritable) vals[1];
        assertEquals (key, expected, dw.get() / lw.get());
      } else if (vals[1] instanceof HiveDecimalWritable) {
        HiveDecimalWritable hdw = (HiveDecimalWritable) vals[1];
        assertEquals (key, expected, hdw.getHiveDecimal().divide(HiveDecimal.create(lw.get())));
      }
    }
  }
}
/**
 * Shared validation for the variance family in PARTIAL1 mode: the output
 * field is a (count, sum, variance-numerator) struct. Subclasses apply the
 * final formula (population vs. sample, variance vs. stddev) in
 * {@link #validateVariance}.
 */
public abstract static class BaseVarianceValidator implements Validator {
  abstract void validateVariance (String key,
      double expected, long cnt, double sum, double variance);
  @Override
  public void validate(String key, Object expected, Object result) {
    Object[] arr = (Object[]) result;
    assertEquals (1, arr.length);
    if (expected == null) {
      assertEquals(null, arr[0]);
    } else {
      assertEquals (true, arr[0] instanceof Object[]);
      Object[] vals = (Object[]) arr[0];
      assertEquals (3, vals.length);
      assertEquals (true, vals[0] instanceof LongWritable);
      assertEquals (true, vals[1] instanceof DoubleWritable);
      assertEquals (true, vals[2] instanceof DoubleWritable);
      LongWritable cnt = (LongWritable) vals[0];
      DoubleWritable sum = (DoubleWritable) vals[1];
      DoubleWritable var = (DoubleWritable) vals[2];
      // A non-null variance requires at least one counted row.
      assertTrue (1 <= cnt.get());
      validateVariance (key, (Double) expected, cnt.get(), sum.get(), var.get());
    }
  }
}
/** Population variance: the accumulated numerator divided by the row count. */
public static class VarianceValidator extends BaseVarianceValidator {
  @Override
  void validateVariance(String key, double expected, long cnt, double sum, double variance) {
    double varPop = variance / cnt;
    assertEquals(key, expected, varPop, 0.0);
  }
}
/** Sample variance: the accumulated numerator divided by (count - 1). */
public static class VarianceSampValidator extends BaseVarianceValidator {
  @Override
  void validateVariance(String key, double expected, long cnt, double sum, double variance) {
    double varSamp = variance / (cnt - 1);
    assertEquals(key, expected, varSamp, 0.0);
  }
}
/** Population standard deviation: sqrt of the population variance. */
public static class StdValidator extends BaseVarianceValidator {
  @Override
  void validateVariance(String key, double expected, long cnt, double sum, double variance) {
    double stdPop = Math.sqrt(variance / cnt);
    assertEquals(key, expected, stdPop, 0.0);
  }
}
/** Sample standard deviation: sqrt of the sample variance. */
public static class StdSampValidator extends BaseVarianceValidator {
  @Override
  void validateVariance(String key, double expected, long cnt, double sum, double variance) {
    double stdSamp = Math.sqrt(variance / (cnt - 1));
    assertEquals(key, expected, stdSamp, 0.0);
  }
}
// Aggregate-name -> Validator-class lookup table used by getValidator().
// Lookups are case-insensitive. Note that "variance"/"var_pop" share
// VarianceValidator and "std"/"stddev"/"stddev_pop" share StdValidator,
// matching Hive's aggregate aliases.
private static Object[][] validators = {
    {"count", ValueValidator.class},
    {"min", ValueValidator.class},
    {"max", ValueValidator.class},
    {"sum", ValueValidator.class},
    {"avg", AvgValidator.class},
    {"variance", VarianceValidator.class},
    {"var_pop", VarianceValidator.class},
    {"var_samp", VarianceSampValidator.class},
    {"std", StdValidator.class},
    {"stddev", StdValidator.class},
    {"stddev_pop", StdValidator.class},
    {"stddev_samp", StdSampValidator.class},
};
/**
 * Looks up (case-insensitively) and instantiates the Validator registered
 * for the given aggregate name in the {@code validators} table.
 *
 * @throws HiveException if no validator is registered, or reflection fails
 */
public static Validator getValidator(String aggregate) throws HiveException {
  for (Object[] entry : validators) {
    if (!aggregate.equalsIgnoreCase((String) entry[0])) {
      continue;
    }
    @SuppressWarnings("unchecked")
    Class<? extends Validator> validatorClass = (Class<? extends Validator>) entry[1];
    try {
      return validatorClass.getConstructor().newInstance();
    } catch (Exception e) {
      throw new HiveException(e);
    }
  }
  throw new HiveException("Missing validator for aggregate: " + aggregate);
}
/**
 * Drives a hash-mode count(*) over {@code data} and validates the single
 * output row against {@code expected}.
 */
public void testAggregateCountStarIterable (
    Iterable<VectorizedRowBatch> data,
    Object expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("A");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  GroupByDesc desc = buildGroupByDescCountStar (ctx);
  ((VectorGroupByDesc) desc.getVectorDesc()).setProcessingMode(ProcessingMode.HASH);
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  // close(false) flushes the aggregation buffers to the capture operator.
  vgo.close(false);
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(1, outBatchList.size());
  Object result = outBatchList.get(0);
  Validator validator = getValidator("count");
  validator.validate("_total", expected, result);
}
/**
 * Drives a reduce-side (merge/FINAL) count over {@code data} in GLOBAL
 * processing mode — the mode used when a Reduce stage has no grouping key —
 * and validates the single output row.
 */
public void testAggregateCountReduceIterable (
    Iterable<VectorizedRowBatch> data,
    Object expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("A");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  GroupByDesc desc = buildGroupByDescType(ctx, "count", GenericUDAFEvaluator.Mode.FINAL, "A", TypeInfoFactory.longTypeInfo);
  VectorGroupByDesc vectorDesc = (VectorGroupByDesc) desc.getVectorDesc();
  vectorDesc.setProcessingMode(ProcessingMode.GLOBAL); // Use GLOBAL when no key for Reduce.
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(1, outBatchList.size());
  Object result = outBatchList.get(0);
  Validator validator = getValidator("count");
  validator.validate("_total", expected, result);
}
/**
 * Runs a global (no-key) hash aggregation over a string column and
 * validates the single output row against {@code expected}.
 */
public void testAggregateStringIterable (
    String aggregateName,
    Iterable<VectorizedRowBatch> data,
    Object expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("A");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "A",
      TypeInfoFactory.stringTypeInfo);
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(1, outBatchList.size());
  Object result = outBatchList.get(0);
  Validator validator = getValidator(aggregateName);
  validator.validate("_total", expected, result);
}
/**
 * Runs a global (no-key) hash aggregation over a decimal(30,4) column and
 * validates the single output row against {@code expected}.
 */
public void testAggregateDecimalIterable (
    String aggregateName,
    Iterable<VectorizedRowBatch> data,
    Object expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("A");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  GroupByDesc desc =
      buildGroupByDescType(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "A", TypeInfoFactory.getDecimalTypeInfo(30, 4));
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  for (VectorizedRowBatch unit : data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(1, outBatchList.size());
  Object result = outBatchList.get(0);
  Validator validator = getValidator(aggregateName);
  validator.validate("_total", expected, result);
}
/**
 * Runs a global (no-key) hash aggregation over a double column and
 * validates the single output row against {@code expected}.
 */
public void testAggregateDoubleIterable (
    String aggregateName,
    Iterable<VectorizedRowBatch> data,
    Object expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("A");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  GroupByDesc desc = buildGroupByDescType (ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "A",
      TypeInfoFactory.doubleTypeInfo);
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(1, outBatchList.size());
  Object result = outBatchList.get(0);
  Validator validator = getValidator(aggregateName);
  validator.validate("_total", expected, result);
}
/**
 * Runs a global (no-key) hash aggregation over a long column and validates
 * the single output row against {@code expected}.
 */
public void testAggregateLongIterable (
    String aggregateName,
    Iterable<VectorizedRowBatch> data,
    Object expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("A");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  GroupByDesc desc = buildGroupByDescType(ctx, aggregateName,
      GenericUDAFEvaluator.Mode.PARTIAL1, "A", TypeInfoFactory.longTypeInfo);
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  // Initialize with the shared test configuration, consistent with every
  // other *Iterable driver in this file (this one previously passed null).
  vgo.initialize(hconf, null);
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(1, outBatchList.size());
  Object result = outBatchList.get(0);
  Validator validator = getValidator(aggregateName);
  validator.validate("_total", expected, result);
}
/**
 * Runs a long-keyed hash aggregation over {@code data}; {@code expected}
 * maps each key (as a boxed Long, or null for the NULL group) to its
 * expected aggregate value.
 */
public void testAggregateLongKeyIterable (
    String aggregateName,
    Iterable<VectorizedRowBatch> data,
    HashMap<Object,Object> expected) throws HiveException {
  List<String> mapColumnNames = new ArrayList<String>();
  mapColumnNames.add("Key");
  mapColumnNames.add("Value");
  VectorizationContext ctx = new VectorizationContext("name", mapColumnNames);
  Set<Object> keys = new HashSet<Object>();
  GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value",
      TypeInfoFactory.longTypeInfo, "Key", TypeInfoFactory.longTypeInfo);
  CompilationOpContext cCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(cCtx, desc);
  VectorGroupByOperator vgo =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(cCtx, vgo);
  vgo.initialize(hconf, null);
  out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() {
    private String aggregateName;
    private HashMap<Object,Object> expected;
    private Set<Object> keys;
    @Override
    public void inspectRow(Object row, int tag) throws HiveException {
      assertTrue(row instanceof Object[]);
      Object[] fields = (Object[]) row;
      assertEquals(2, fields.length);
      Object key = fields[0];
      // Unwrap the LongWritable key; a null field means the NULL group.
      Long keyValue = null;
      if (null != key) {
        assertTrue(key instanceof LongWritable);
        LongWritable lwKey = (LongWritable)key;
        keyValue = lwKey.get();
      }
      assertTrue(expected.containsKey(keyValue));
      String keyAsString = String.format("%s", key);
      Object expectedValue = expected.get(keyValue);
      Object value = fields[1];
      Validator validator = getValidator(aggregateName);
      validator.validate(keyAsString, expectedValue, new Object[] {value});
      // Track distinct keys so full coverage can be asserted after close().
      keys.add(keyValue);
    }
    // Poor-man's constructor for the anonymous class: inject captured state.
    private FakeCaptureOutputOperator.OutputInspector init(
        String aggregateName, HashMap<Object,Object> expected, Set<Object> keys) {
      this.aggregateName = aggregateName;
      this.expected = expected;
      this.keys = keys;
      return this;
    }
  }.init(aggregateName, expected, keys));
  for (VectorizedRowBatch unit: data) {
    vgo.process(unit, 0);
  }
  vgo.close(false);
  // One output row per distinct key, each validated exactly once.
  List<Object> outBatchList = out.getCapturedRows();
  assertNotNull(outBatchList);
  assertEquals(expected.size(), outBatchList.size());
  assertEquals(expected.size(), keys.size());
}
/**
 * Runs a vectorized GROUP BY (string key, value of {@code dataTypeInfo}) over the
 * supplied batches and verifies every emitted (key, aggregate) row against the
 * expected map.
 *
 * <p>Fix: drops the unused {@code rowIndex} field the anonymous inspector carried,
 * and replaces the {@code .init(...)} self-injection pattern with direct capture of
 * {@code final} locals — same behavior, less state.
 *
 * @param aggregateName name of the UDAF under test (selects the {@link Validator})
 * @param data input batches with a string "Key" column and a "Value" column
 * @param dataTypeInfo type of the aggregated "Value" column
 * @param expected map from key string (null allowed) to the expected aggregate result
 * @throws HiveException if operator setup or processing fails
 */
public void testAggregateStringKeyIterable (
    String aggregateName,
    Iterable<VectorizedRowBatch> data,
    TypeInfo dataTypeInfo,
    HashMap<Object,Object> expected) throws HiveException {
  List<String> columnNames = new ArrayList<String>();
  columnNames.add("Key");
  columnNames.add("Value");
  VectorizationContext ctx = new VectorizationContext("name", columnNames);
  GroupByDesc desc = buildKeyGroupByDesc(ctx, aggregateName, "Value",
      dataTypeInfo, "Key", TypeInfoFactory.stringTypeInfo);
  CompilationOpContext opCtx = new CompilationOpContext();
  Operator<? extends OperatorDesc> groupByOp = OperatorFactory.get(opCtx, desc);
  VectorGroupByOperator vectorGroupBy =
      (VectorGroupByOperator) Vectorizer.vectorizeGroupByOperator(groupByOp, ctx);
  FakeCaptureOutputOperator capture =
      FakeCaptureOutputOperator.addCaptureOutputChild(opCtx, vectorGroupBy);
  vectorGroupBy.initialize(hconf, null);

  // Final copies so the anonymous inspector can capture them directly.
  final String aggName = aggregateName;
  final HashMap<Object,Object> expectedByKey = expected;
  final Set<Object> seenKeys = new HashSet<Object>();

  // Validate each output row as it is emitted; also record which keys we saw.
  capture.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() {
    @SuppressWarnings("deprecation")
    @Override
    public void inspectRow(Object row, int tag) throws HiveException {
      assertTrue(row instanceof Object[]);
      Object[] fields = (Object[]) row;
      assertEquals(2, fields.length);
      Object key = fields[0];
      // A null key stays null; otherwise unwrap the Text to its String value.
      String keyValue = null;
      if (key != null) {
        assertTrue(key instanceof Text);
        keyValue = ((Text) key).toString();
      }
      assertTrue(expectedByKey.containsKey(keyValue));
      Validator validator = getValidator(aggName);
      validator.validate(String.format("%s", key), expectedByKey.get(keyValue),
          new Object[] {fields[1]});
      seenKeys.add(keyValue);
    }
  });

  for (VectorizedRowBatch batch : data) {
    vectorGroupBy.process(batch, 0);
  }
  vectorGroupBy.close(false);

  // One output row per expected key, and every expected key observed exactly once.
  List<Object> capturedRows = capture.getCapturedRows();
  assertNotNull(capturedRows);
  assertEquals(expected.size(), capturedRows.size());
  assertEquals(expected.size(), seenKeys.size());
}
}