/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.udf.generic;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.io.LongWritable;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import jersey.repackaged.com.google.common.collect.Lists;
@RunWith(Parameterized.class)
public class TestGenericUDAFBinarySetFunctions {
private List<Object[]> rowSet;
/**
 * Supplies the scenarios for the parameterized runner.
 *
 * <p>Every entry is a {@code { label, rows }} pair; each generated row carries two column
 * values. The scenarios cover sequences, constant columns, an empty row set, a single row
 * and columns that are entirely {@code null}.
 *
 * @return one {@code Object[]} per scenario, in declaration order
 */
@Parameters(name = "{0}")
public static List<Object[]> getParameters() {
  List<Object[]> scenarios = new ArrayList<>();
  scenarios.add(scenario("seq/seq", 10,
      new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.DoubleSequence(0)));
  scenarios.add(scenario("seq/ones", 10,
      new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.ConstantSequence(1.0)));
  scenarios.add(scenario("ones/seq", 10,
      new RowSetGenerator.ConstantSequence(1.0), new RowSetGenerator.DoubleSequence(0)));
  // no rows at all
  scenarios.add(scenario("empty", 0,
      new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.DoubleSequence(0)));
  // exactly one row
  scenarios.add(scenario("lonely", 1,
      new RowSetGenerator.DoubleSequence(10), new RowSetGenerator.DoubleSequence(10)));
  scenarios.add(scenario("seq/seq+10", 10,
      new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.DoubleSequence(10)));
  // one column is null in every row
  scenarios.add(scenario("seq/null", 10,
      new RowSetGenerator.DoubleSequence(0), new RowSetGenerator.ConstantSequence(null)));
  scenarios.add(scenario("null/seq0", 10,
      new RowSetGenerator.ConstantSequence(null), new RowSetGenerator.DoubleSequence(0)));
  return scenarios;
}

/**
 * Packs a label and a freshly generated row set into one parameter entry.
 *
 * @param label name of the scenario
 * @param numRows number of rows to generate
 * @param generators one generator per column
 * @return {@code { label, rows }}
 */
private static Object[] scenario(String label, int numRows,
    RowSetGenerator.FieldGenerator... generators) {
  return new Object[] { label, RowSetGenerator.generate(numRows, generators) };
}
/**
 * Evaluates one UDAF over a row set along three different evaluator-mode paths so that
 * their results can be compared:
 * <ol>
 * <li>COMPLETE only,</li>
 * <li>PARTIAL1 batches merged by a FINAL evaluator,</li>
 * <li>PARTIAL1 batches, re-merged by PARTIAL2 evaluators, then merged by a FINAL evaluator.</li>
 * </ol>
 */
public static class GenericUDAFExecutor {

  private final GenericUDAFResolver2 evaluatorFactory;
  private final GenericUDAFParameterInfo info;
  /** Inspector returned by a PARTIAL1 {@code init}; passed as input to PARTIAL2 and FINAL. */
  private final ObjectInspector[] partialOIs;

  /**
   * @param evaluatorFactory resolver that creates the evaluators; a new one is requested
   *          for every step
   * @param info parameter description handed to the resolver
   * @throws Exception if the evaluator cannot be created or initialized
   */
  public GenericUDAFExecutor(GenericUDAFResolver2 evaluatorFactory, GenericUDAFParameterInfo info)
      throws Exception {
    this.evaluatorFactory = evaluatorFactory;
    this.info = info;

    // A throw-away evaluator is initialized only to learn the partial-result inspector.
    GenericUDAFEvaluator eval0 = evaluatorFactory.getEvaluator(info);
    partialOIs = new ObjectInspector[] {
        eval0.init(GenericUDAFEvaluator.Mode.PARTIAL1, info.getParameterObjectInspectors()) };
  }

  /**
   * Runs all three paths over {@code values}.
   *
   * @param values input rows; each element is the parameter array of one {@code iterate} call
   * @return the three results in the order complete, partial1+final, partial1+partial2+final;
   *         elements are whatever {@code terminate} returned and may be {@code null}
   * @throws Exception if any evaluator call fails
   */
  List<Object> run(List<Object[]> values) throws Exception {
    // Plain ArrayList: it accepts null results and avoids relying on a repackaged Guava class.
    List<Object> results = new ArrayList<>(3);
    results.add(runComplete(values));
    results.add(runPartialFinal(values));
    results.add(runPartial2Final(values));
    return results;
  }

  /** Feeds every row to a single COMPLETE-mode evaluator and returns its final value. */
  private Object runComplete(List<Object[]> values) throws SemanticException, HiveException {
    GenericUDAFEvaluator eval = evaluatorFactory.getEvaluator(info);
    eval.init(GenericUDAFEvaluator.Mode.COMPLETE, info.getParameterObjectInspectors());
    AggregationBuffer agg = eval.getNewAggregationBuffer();
    for (Object[] parameters : values) {
      eval.iterate(agg, parameters);
    }
    return eval.terminate(agg);
  }

  /** PARTIAL1 batches merged directly by a FINAL-mode evaluator. */
  private Object runPartialFinal(List<Object[]> values) throws Exception {
    return mergeInFinalMode(values, false);
  }

  /** PARTIAL1 batches, combined once more in PARTIAL2 mode, then merged in FINAL mode. */
  private Object runPartial2Final(List<Object[]> values) throws Exception {
    return mergeInFinalMode(values, true);
  }

  /**
   * Shared tail of the two partial paths: merges partial results in a FINAL-mode evaluator.
   *
   * @param values input rows
   * @param viaPartial2 whether the PARTIAL1 outputs pass through {@link #runPartial2} first
   * @return the value returned by {@code terminate}
   */
  private Object mergeInFinalMode(List<Object[]> values, boolean viaPartial2) throws Exception {
    GenericUDAFEvaluator eval = evaluatorFactory.getEvaluator(info);
    eval.init(GenericUDAFEvaluator.Mode.FINAL, partialOIs);
    AggregationBuffer buf = eval.getNewAggregationBuffer();

    List<Object> partials = runPartial1(values);
    if (viaPartial2) {
      partials = runPartial2(partials);
    }
    for (Object partialResult : partials) {
      eval.merge(buf, partialResult);
    }
    return eval.terminate(buf);
  }

  /**
   * Splits the rows into batches, each aggregated by its own PARTIAL1-mode evaluator.
   *
   * <p>The loop bound is {@code batchSize - 1} with {@code batchSize} starting at 1 and
   * doubling, so the batches take at most 0, 1, 3, 7, ... rows. The first partial is
   * therefore always produced from an empty buffer.
   *
   * @return one {@code terminatePartial} result per batch; never empty
   */
  private List<Object> runPartial1(List<Object[]> values) throws Exception {
    List<Object> ret = new ArrayList<>();
    int batchSize = 1;
    Iterator<Object[]> iter = values.iterator();
    // do/while: the condition is tested after the body so that even an empty input
    // yields one (empty) partial.
    do {
      GenericUDAFEvaluator eval = evaluatorFactory.getEvaluator(info);
      eval.init(GenericUDAFEvaluator.Mode.PARTIAL1, info.getParameterObjectInspectors());
      AggregationBuffer buf = eval.getNewAggregationBuffer();
      for (int i = 0; i < batchSize - 1 && iter.hasNext(); i++) {
        eval.iterate(buf, iter.next());
      }
      batchSize <<= 1;
      ret.add(eval.terminatePartial(buf));
    } while (iter.hasNext());
    return ret;
  }

  /**
   * Combines partial results in batches, each merged by its own PARTIAL2-mode evaluator.
   *
   * <p>Uses the same batching scheme as {@link #runPartial1}: at most 0, 1, 3, 7, ...
   * inputs per batch, and at least one output.
   *
   * @param values partial results, as produced by {@link #runPartial1}
   * @return one {@code terminatePartial} result per batch; never empty
   */
  private List<Object> runPartial2(List<Object> values) throws Exception {
    List<Object> ret = new ArrayList<>();
    int batchSize = 1;
    Iterator<Object> iter = values.iterator();
    // do/while: always emit at least one partial; the first one merges nothing.
    do {
      GenericUDAFEvaluator eval = evaluatorFactory.getEvaluator(info);
      eval.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOIs);
      AggregationBuffer buf = eval.getNewAggregationBuffer();
      for (int i = 0; i < batchSize - 1 && iter.hasNext(); i++) {
        eval.merge(buf, iter.next());
      }
      batchSize <<= 1;
      ret.add(eval.terminatePartial(buf));
    } while (iter.hasNext());
    return ret;
  }
}
/**
 * Produces synthetic row sets in which every column value is derived from the row index
 * by a {@link FieldGenerator}.
 */
public static class RowSetGenerator {

  /** Computes the value of one column for a given row. */
  public interface FieldGenerator {
    /**
     * @param rowIndex zero-based index of the row being generated
     * @return the column value; may be {@code null}
     */
    Object apply(int rowIndex);
  }

  /** Column holding the same value (possibly {@code null}) in every row. */
  public static class ConstantSequence implements FieldGenerator {
    private final Object value;

    public ConstantSequence(Object constant) {
      value = constant;
    }

    @Override
    public Object apply(int rowIndex) {
      return value;
    }
  }

  /** Column holding {@code rowIndex + offset} as a {@code Double}. */
  public static class DoubleSequence implements FieldGenerator {
    private final int start;

    public DoubleSequence(int offset) {
      start = offset;
    }

    @Override
    public Object apply(int rowIndex) {
      // the sum is computed in int arithmetic and then widened
      return Double.valueOf(rowIndex + start);
    }
  }

  /**
   * Generates {@code numRows} rows with one column per generator.
   *
   * @param numRows number of rows to create
   * @param generators column generators, applied left to right for each row
   * @return the rows, each an {@code Object[]} of length {@code generators.length}
   */
  public static List<Object[]> generate(int numRows, FieldGenerator... generators) {
    List<Object[]> rows = new ArrayList<>(numRows);
    int rowIdx = 0;
    while (rowIdx < numRows) {
      Object[] row = new Object[generators.length];
      for (int col = 0; col < row.length; col++) {
        row[col] = generators[col].apply(rowIdx);
      }
      rows.add(row);
      rowIdx++;
    }
    return rows;
  }
}
/**
 * Receives one entry produced by {@code getParameters()}.
 *
 * @param label scenario name; not stored here, it only serves the {@code "{0}"} name
 *          pattern declared on {@code getParameters()}
 * @param rowSet rows every test method of this instance aggregates over
 */
public TestGenericUDAFBinarySetFunctions(String label, List<Object[]> rowSet) {
this.rowSet = rowSet;
}
/** Compares {@code RegrCount} with the reference count of the current row set. */
@Test
public void regr_count() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).count();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrCount());
}
/** Compares {@code RegrSXX} with the reference {@code sxx} of the current row set. */
@Test
public void regr_sxx() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).sxx();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrSXX());
}
/** Compares {@code RegrSYY} with the reference {@code syy} of the current row set. */
@Test
public void regr_syy() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).syy();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrSYY());
}
/** Compares {@code RegrSXY} with the reference {@code sxy} of the current row set. */
@Test
public void regr_sxy() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).sxy();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrSXY());
}
/** Compares {@code RegrAvgX} with the reference {@code avgx} of the current row set. */
@Test
public void regr_avgx() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).avgx();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrAvgX());
}
/** Compares {@code RegrAvgY} with the reference {@code avgy} of the current row set. */
@Test
public void regr_avgy() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).avgy();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrAvgY());
}
/** Compares {@code RegrSlope} with the reference slope of the current row set. */
@Test
public void regr_slope() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).slope();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrSlope());
}
/** Compares {@code RegrR2} with the reference {@code r2} of the current row set. */
@Test
public void regr_r2() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).r2();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrR2());
}
/** Compares {@code RegrIntercept} with the reference intercept of the current row set. */
@Test
public void regr_intercept() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).intercept();
  validateUDAF(reference, new GenericUDAFBinarySetFunctions.RegrIntercept());
}
/** Compares {@code GenericUDAFCorrelation} with the reference {@code corr} value. */
@Test
public void corr() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).corr();
  validateUDAF(reference, new GenericUDAFCorrelation());
}
/** Compares {@code GenericUDAFCovariance} with the reference {@code covar_pop} value. */
@Test
public void covar_pop() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).covar_pop();
  validateUDAF(reference, new GenericUDAFCovariance());
}
/** Compares {@code GenericUDAFCovarianceSample} with the reference {@code covar_samp} value. */
@Test
public void covar_samp() throws Exception {
  Double reference = RegrIntermediate.computeFor(rowSet).covar_samp();
  validateUDAF(reference, new GenericUDAFCovarianceSample());
}
/**
 * Runs {@code udaf} over the current row set through {@code GenericUDAFExecutor} and checks
 * every value it returns against {@code expectedResult}.
 *
 * <p>Both parameters of the aggregate are declared as Java doubles. Numeric results are
 * compared with a tolerance of {@code 1e-10}.
 *
 * @param expectedResult reference value; {@code null} means every result must be {@code null}
 * @param udaf resolver of the aggregate under test
 * @throws Exception if evaluating the aggregate fails
 */
private void validateUDAF(Double expectedResult, GenericUDAFResolver2 udaf) throws Exception {
  ObjectInspector[] params =
      new ObjectInspector[] { javaDoubleObjectInspector, javaDoubleObjectInspector };
  GenericUDAFParameterInfo gpi = new SimpleGenericUDAFParameterInfo(params, false, false, false);
  GenericUDAFExecutor executor = new GenericUDAFExecutor(udaf, gpi);
  List<Object> values = executor.run(rowSet);

  String udafName = udaf.getClass().getSimpleName();
  for (int i = 0; i < values.size(); i++) {
    Object v = values.get(i);
    // The index tells which of the executor's evaluation paths produced the value.
    String where = udafName + " result #" + i;

    if (expectedResult == null) {
      assertNull(where, v);
      continue;
    }

    double expected = expectedResult.doubleValue();
    if (v instanceof DoubleWritable) {
      assertEquals(where, expected, ((DoubleWritable) v).get(), 1e-10);
    } else if (v instanceof LongWritable) {
      assertEquals(where, expected, ((LongWritable) v).get(), 1e-10);
    } else {
      // A null or unexpected result type is reported as a test failure rather than
      // surfacing as a NullPointerException/ClassCastException from a blind cast.
      throw new AssertionError(where + ": expected <" + expected + "> but was <" + v + ">");
    }
  }
}
/**
 * Straightforward reference implementation of the two-argument aggregates, built on
 * running sums of {@code x}, {@code y}, {@code x*x}, {@code y*y} and {@code x*y}.
 *
 * <p>Pairs in which either value is {@code null} are ignored. Methods return {@code null}
 * where the value is undefined for the accumulated data.
 */
static class RegrIntermediate {
  public double sum_x2, sum_y2;
  public double sum_x, sum_y;
  public double sum_xy;
  /** Number of accepted (non-null) pairs. */
  public double n;

  /**
   * Accumulates one pair; does nothing if either value is {@code null}.
   *
   * @param y first argument of the aggregate
   * @param x second argument of the aggregate
   */
  public void add(Double y, Double x) {
    if (x == null || y == null) {
      return;
    }
    sum_x2 += x * x;
    sum_y2 += y * y;
    sum_x += x;
    sum_y += y;
    sum_xy += x * y;
    n++;
  }

  /** Intercept of the least-squares line; {@code null} without data or when x has no spread. */
  public Double intercept() {
    double xx = n * sum_x2 - sum_x * sum_x;
    if (n == 0 || xx == 0.0d) {
      return null;
    }
    return (sum_y * sum_x2 - sum_x * sum_xy) / xx;
  }

  /** {@code sum(x*y) - sum(x)*sum(y)/n}; {@code null} without data. */
  public Double sxy() {
    if (n == 0) {
      return null;
    }
    return sum_xy - sum_x * sum_y / n;
  }

  /** Population covariance; {@code null} without data. */
  public Double covar_pop() {
    if (n == 0) {
      return null;
    }
    return (sum_xy - sum_x * sum_y / n) / n;
  }

  /** Sample covariance; {@code null} for fewer than two pairs. */
  public Double covar_samp() {
    if (n <= 1) {
      return null;
    }
    return (sum_xy - sum_x * sum_y / n) / (n - 1);
  }

  /**
   * Correlation coefficient of x and y.
   *
   * @return the signed coefficient, or {@code null} without data or when either variable
   *         has no spread
   */
  public Double corr() {
    double xx = n * sum_x2 - sum_x * sum_x;
    double yy = n * sum_y2 - sum_y * sum_y;
    if (n == 0 || xx == 0.0d || yy == 0.0d) {
      return null;
    }
    double c = n * sum_xy - sum_x * sum_y;
    // Divide instead of taking sqrt(c*c/xx/yy): squaring c would drop the sign and
    // report negatively correlated data as positively correlated.
    return c / Math.sqrt(xx * yy);
  }

  /**
   * Squared correlation; {@code null} without data or when x has no spread, and
   * {@code 1} when y has no spread.
   */
  public Double r2() {
    double xx = n * sum_x2 - sum_x * sum_x;
    double yy = n * sum_y2 - sum_y * sum_y;
    if (n == 0 || xx == 0.0d) {
      return null;
    }
    if (yy == 0.0d) {
      return 1.0d;
    }
    double c = n * sum_xy - sum_x * sum_y;
    return c * c / xx / yy;
  }

  /** Slope of the least-squares line; {@code null} without data or when x has no spread. */
  public Double slope() {
    if (n == 0 || n * sum_x2 == sum_x * sum_x) {
      return null;
    }
    return (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
  }

  /** Mean of x over the accepted pairs; {@code null} without data. */
  public Double avgx() {
    if (n == 0) {
      return null;
    }
    return sum_x / n;
  }

  /** Mean of y over the accepted pairs; {@code null} without data. */
  public Double avgy() {
    if (n == 0) {
      return null;
    }
    return sum_y / n;
  }

  /** Number of accepted pairs; never {@code null}. */
  public Double count() {
    return n;
  }

  /** {@code sum(x*x) - sum(x)^2/n}; {@code null} without data. */
  public Double sxx() {
    if (n == 0) {
      return null;
    }
    return sum_x2 - sum_x * sum_x / n;
  }

  /** {@code sum(y*y) - sum(y)^2/n}; {@code null} without data. */
  public Double syy() {
    if (n == 0) {
      return null;
    }
    return sum_y2 - sum_y * sum_y / n;
  }

  /**
   * Accumulates a whole row set.
   *
   * @param rows rows whose element 0 is used as y and element 1 as x; both must be
   *          {@code Double} or {@code null}
   * @return the filled accumulator
   */
  public static RegrIntermediate computeFor(List<Object[]> rows) {
    RegrIntermediate ri = new RegrIntermediate();
    for (Object[] objects : rows) {
      ri.add((Double) objects[0], (Double) objects[1]);
    }
    return ri;
  }
}
}