/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression.AggregationBuffer;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.AggregationDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFBloomFilter.GenericUDAFBloomFilterEvaluator;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text;
import org.apache.hive.common.util.BloomFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Vectorized implementation of the {@code bloom_filter} aggregate.
 *
 * <p>Every non-null value of the input expression is added to a {@link BloomFilter}; the final
 * result is the serialized filter returned as a {@link BytesWritable}.
 */
public class VectorUDAFBloomFilter extends VectorAggregateExpression {

  private static final Logger LOG = LoggerFactory.getLogger(VectorUDAFBloomFilter.class);

  private static final long serialVersionUID = 1L;

  /** Expression producing the column whose values are added to the filter. */
  private VectorExpression inputExpression;

  @Override
  public VectorExpression inputExpression() {
    return inputExpression;
  }

  /** Expected number of entries used to size each new BloomFilter; -1 until {@link #init}. */
  private long expectedEntries = -1;

  /** Type-specific adapter that adds one column value to a filter. */
  private ValueProcessor valueProcessor;

  // Cached length of the filter's backing bit-set array; <= 0 means "not computed yet".
  // These fields are transient: Java deserialization does not run their initializers, so the
  // code below tolerates zero/null values and re-creates them on demand.
  transient private int bitSetSize = -1;
  transient private BytesWritable bw = new BytesWritable();
  transient private ByteArrayOutputStream byteStream = new ByteArrayOutputStream();

  /**
   * class for storing the current aggregate value.
   */
  private static final class Aggregation implements AggregationBuffer {

    private static final long serialVersionUID = 1L;

    /** The filter accumulating all values seen for this group. */
    BloomFilter bf;

    public Aggregation(long expectedEntries) {
      bf = new BloomFilter(expectedEntries);
    }

    /** Variable-size accounting is not supported for this buffer. */
    @Override
    public int getVariableSize() {
      throw new UnsupportedOperationException();
    }

    /** Clears the filter so the buffer can be reused for a new group. */
    @Override
    public void reset() {
      bf.reset();
    }
  }

  /**
   * Creates the aggregate and selects the value processor matching the input's output type.
   *
   * @param inputExpression expression whose output column feeds the filter
   * @throws IllegalStateException if the input type is not supported
   */
  public VectorUDAFBloomFilter(VectorExpression inputExpression) {
    this();
    this.inputExpression = inputExpression;

    // Instantiate the ValueProcessor based on the input type
    VectorExpressionDescriptor.ArgumentType inputType =
        VectorExpressionDescriptor.ArgumentType.fromHiveTypeName(inputExpression.getOutputType());
    switch (inputType) {
    case INT_FAMILY:
    case DATE:
      valueProcessor = new ValueProcessorLong();
      break;
    case FLOAT_FAMILY:
      valueProcessor = new ValueProcessorDouble();
      break;
    case DECIMAL:
      valueProcessor = new ValueProcessorDecimal();
      break;
    case STRING:
    case CHAR:
    case VARCHAR:
    case STRING_FAMILY:
    case BINARY:
      valueProcessor = new ValueProcessorBytes();
      break;
    case TIMESTAMP:
      valueProcessor = new ValueProcessorTimestamp();
      break;
    default:
      throw new IllegalStateException("Unsupported type " + inputType);
    }
  }

  public VectorUDAFBloomFilter() {
    super();
  }

  /**
   * Allocates a buffer holding an empty filter sized for {@link #getExpectedEntries()}.
   *
   * @throws IllegalStateException if {@link #init} / {@link #setExpectedEntries} was not called
   */
  @Override
  public AggregationBuffer getNewAggregationBuffer() throws HiveException {
    if (expectedEntries < 0) {
      throw new IllegalStateException("expectedEntries not initialized");
    }
    return new Aggregation(expectedEntries);
  }

  /**
   * Adds every non-null input value of {@code batch} to the single buffer {@code agg}.
   */
  @Override
  public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
      throws HiveException {
    int batchSize = batch.size;
    if (batchSize == 0) {
      return;
    }

    inputExpression.evaluate(batch);
    ColumnVector inputColumn = batch.cols[this.inputExpression.getOutputColumn()];
    Aggregation myagg = (Aggregation) agg;

    if (inputColumn.isRepeating) {
      // Every row carries the entry at index 0. When noNulls is false, isNull[0] decides whether
      // that entry is null. Adding it once is enough: re-adding a value does not change a
      // Bloom filter.
      if (inputColumn.noNulls || !inputColumn.isNull[0]) {
        valueProcessor.processValue(myagg, inputColumn, 0);
      }
      return;
    }

    if (!batch.selectedInUse && inputColumn.noNulls) {
      iterateNoSelectionNoNulls(myagg, inputColumn, batchSize);
    } else if (!batch.selectedInUse) {
      iterateNoSelectionHasNulls(myagg, inputColumn, batchSize);
    } else if (inputColumn.noNulls) {
      iterateSelectionNoNulls(myagg, inputColumn, batchSize, batch.selected);
    } else {
      iterateSelectionHasNulls(myagg, inputColumn, batchSize, batch.selected);
    }
  }

  /** Adds rows 0..batchSize-1; the column has no nulls. */
  private void iterateNoSelectionNoNulls(
      Aggregation myagg,
      ColumnVector inputColumn,
      int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      valueProcessor.processValue(myagg, inputColumn, i);
    }
  }

  /** Adds non-null rows among 0..batchSize-1. */
  private void iterateNoSelectionHasNulls(
      Aggregation myagg,
      ColumnVector inputColumn,
      int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      if (!inputColumn.isNull[i]) {
        valueProcessor.processValue(myagg, inputColumn, i);
      }
    }
  }

  /** Adds the rows listed in {@code selected}; the column has no nulls. */
  private void iterateSelectionNoNulls(
      Aggregation myagg,
      ColumnVector inputColumn,
      int batchSize,
      int[] selected) {
    for (int j = 0; j < batchSize; ++j) {
      int i = selected[j];
      valueProcessor.processValue(myagg, inputColumn, i);
    }
  }

  /** Adds the non-null rows listed in {@code selected}. */
  private void iterateSelectionHasNulls(
      Aggregation myagg,
      ColumnVector inputColumn,
      int batchSize,
      int[] selected) {
    for (int j = 0; j < batchSize; ++j) {
      int i = selected[j];
      if (!inputColumn.isNull[i]) {
        valueProcessor.processValue(myagg, inputColumn, i);
      }
    }
  }

  /**
   * Adds each input value to the per-row buffer found at
   * {@code aggregationBufferSets[logicalIndex].getAggregationBuffer(aggregateIndex)}.
   *
   * <p>Buffers are indexed by logical position in the batch (0..size-1), while values are read
   * from the physical row (the selected index when {@code batch.selectedInUse}).
   */
  @Override
  public void aggregateInputSelection(
      VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex,
      VectorizedRowBatch batch) throws HiveException {
    int batchSize = batch.size;
    if (batchSize == 0) {
      return;
    }

    inputExpression.evaluate(batch);
    ColumnVector inputColumn = batch.cols[this.inputExpression.getOutputColumn()];

    if (inputColumn.isRepeating) {
      // A repeating column is null everywhere only when noNulls is false AND isNull[0] is set.
      if (inputColumn.noNulls || !inputColumn.isNull[0]) {
        iterateNoNullsRepeatingWithAggregationSelection(
            aggregationBufferSets, aggregateIndex,
            inputColumn, batchSize);
      }
      return;
    }

    if (inputColumn.noNulls) {
      if (batch.selectedInUse) {
        iterateNoNullsSelectionWithAggregationSelection(
            aggregationBufferSets, aggregateIndex,
            inputColumn, batch.selected, batchSize);
      } else {
        iterateNoNullsWithAggregationSelection(
            aggregationBufferSets, aggregateIndex,
            inputColumn, batchSize);
      }
    } else {
      if (batch.selectedInUse) {
        iterateHasNullsSelectionWithAggregationSelection(
            aggregationBufferSets, aggregateIndex,
            inputColumn, batchSize, batch.selected);
      } else {
        iterateHasNullsWithAggregationSelection(
            aggregationBufferSets, aggregateIndex,
            inputColumn, batchSize);
      }
    }
  }

  /** Adds the single repeated (non-null) value to every row's buffer. */
  private void iterateNoNullsRepeatingWithAggregationSelection(
      VectorAggregationBufferRow[] aggregationBufferSets,
      int aggregrateIndex,
      ColumnVector inputColumn,
      int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      Aggregation myagg = getCurrentAggregationBuffer(
          aggregationBufferSets,
          aggregrateIndex,
          i);
      valueProcessor.processValue(myagg, inputColumn, 0);
    }
  }

  /** Buffer i receives the value at physical row selection[i]; no nulls present. */
  private void iterateNoNullsSelectionWithAggregationSelection(
      VectorAggregationBufferRow[] aggregationBufferSets,
      int aggregrateIndex,
      ColumnVector inputColumn,
      int[] selection,
      int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      int row = selection[i];
      Aggregation myagg = getCurrentAggregationBuffer(
          aggregationBufferSets,
          aggregrateIndex,
          i);
      valueProcessor.processValue(myagg, inputColumn, row);
    }
  }

  /** Buffer i receives the value at row i; no nulls present. */
  private void iterateNoNullsWithAggregationSelection(
      VectorAggregationBufferRow[] aggregationBufferSets,
      int aggregrateIndex,
      ColumnVector inputColumn,
      int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      Aggregation myagg = getCurrentAggregationBuffer(
          aggregationBufferSets,
          aggregrateIndex,
          i);
      valueProcessor.processValue(myagg, inputColumn, i);
    }
  }

  /** Buffer i receives the value at physical row selection[i], skipping null rows. */
  private void iterateHasNullsSelectionWithAggregationSelection(
      VectorAggregationBufferRow[] aggregationBufferSets,
      int aggregrateIndex,
      ColumnVector inputColumn,
      int batchSize,
      int[] selection) {
    for (int i = 0; i < batchSize; ++i) {
      int row = selection[i];
      if (!inputColumn.isNull[row]) {
        Aggregation myagg = getCurrentAggregationBuffer(
            aggregationBufferSets,
            aggregrateIndex,
            i);
        // Read the value from the same physical row whose null flag was checked.
        valueProcessor.processValue(myagg, inputColumn, row);
      }
    }
  }

  /** Buffer i receives the value at row i, skipping null rows. */
  private void iterateHasNullsWithAggregationSelection(
      VectorAggregationBufferRow[] aggregationBufferSets,
      int aggregrateIndex,
      ColumnVector inputColumn,
      int batchSize) {
    for (int i = 0; i < batchSize; ++i) {
      if (!inputColumn.isNull[i]) {
        Aggregation myagg = getCurrentAggregationBuffer(
            aggregationBufferSets,
            aggregrateIndex,
            i);
        valueProcessor.processValue(myagg, inputColumn, i);
      }
    }
  }

  /** Returns this aggregate's buffer from the buffer set at logical position {@code row}. */
  private Aggregation getCurrentAggregationBuffer(
      VectorAggregationBufferRow[] aggregationBufferSets,
      int aggregrateIndex,
      int row) {
    VectorAggregationBufferRow mySet = aggregationBufferSets[row];
    Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregrateIndex);
    return myagg;
  }

  @Override
  public void reset(AggregationBuffer agg) throws HiveException {
    agg.reset();
  }

  /**
   * Serializes the buffer's filter.
   *
   * <p>The returned {@link BytesWritable} is a field of this instance and is overwritten by the
   * next call, so callers must consume or copy it before calling again.
   *
   * @throws HiveException if serialization fails
   */
  @Override
  public Object evaluateOutput(AggregationBuffer agg) throws HiveException {
    // Transient fields may be null after Java deserialization; recreate them on demand.
    if (byteStream == null) {
      byteStream = new ByteArrayOutputStream();
    }
    if (bw == null) {
      bw = new BytesWritable();
    }
    try {
      Aggregation bfAgg = (Aggregation) agg;
      byteStream.reset();
      BloomFilter.serialize(byteStream, bfAgg.bf);
      byte[] bytes = byteStream.toByteArray();
      bw.set(bytes, 0, bytes.length);
      return bw;
    } catch (IOException err) {
      throw new HiveException("Error encountered while serializing bloomfilter", err);
    }
  }

  @Override
  public ObjectInspector getOutputObjectInspector() {
    return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector;
  }

  /**
   * Estimates the per-buffer memory footprint. The bit-set length is learned once by building
   * a throw-away buffer.
   */
  @Override
  public long getAggregationBufferFixedSize() {
    // <= 0 (not just < 0): a Java-deserialized instance sees the transient default of 0, and a
    // real filter always has a non-empty bit set.
    if (bitSetSize <= 0) {
      // Not pretty, but we need a way to get the size
      try {
        Aggregation agg = (Aggregation) getNewAggregationBuffer();
        bitSetSize = agg.bf.getBitSet().length;
      } catch (Exception e) {
        throw new RuntimeException("Unexpected error while creating AggregationBuffer", e);
      }
    }

    // BloomFilter: object(BitSet: object(data: long[]), numBits: int, numHashFunctions: int)
    JavaDataModel model = JavaDataModel.get();
    long bitSetObjectSize = JavaDataModel.alignUp(
        model.object() + model.lengthForLongArrayOfSize(bitSetSize),
        model.memoryAlign());
    return JavaDataModel.alignUp(
        model.object() + bitSetObjectSize + model.primitive1() + model.primitive1(),
        model.memoryAlign());
  }

  /** Reads the expected entry count from the plan's {@link GenericUDAFBloomFilterEvaluator}. */
  @Override
  public void init(AggregationDesc desc) throws HiveException {
    GenericUDAFBloomFilterEvaluator udafBloomFilter =
        (GenericUDAFBloomFilterEvaluator) desc.getGenericUDAFEvaluator();
    expectedEntries = udafBloomFilter.getExpectedEntries();
  }

  public VectorExpression getInputExpression() {
    return inputExpression;
  }

  public void setInputExpression(VectorExpression inputExpression) {
    this.inputExpression = inputExpression;
  }

  public long getExpectedEntries() {
    return expectedEntries;
  }

  public void setExpectedEntries(long expectedEntries) {
    this.expectedEntries = expectedEntries;
  }

  // Type-specific handling done here.
  // Serializable because the enclosing (serializable) aggregate holds an instance in a
  // non-transient field.
  private static abstract class ValueProcessor implements java.io.Serializable {
    private static final long serialVersionUID = 1L;

    /** Adds the value at {@code index} of {@code inputColumn} to {@code myagg}'s filter. */
    abstract protected void processValue(Aggregation myagg, ColumnVector inputColumn, int index);
  }

  //
  // Type-specific implementations
  //

  /** STRING/CHAR/VARCHAR/BINARY: adds the raw byte slice. */
  public static class ValueProcessorBytes extends ValueProcessor {
    @Override
    protected void processValue(Aggregation myagg, ColumnVector columnVector, int i) {
      BytesColumnVector inputColumn = (BytesColumnVector) columnVector;
      myagg.bf.addBytes(inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]);
    }
  }

  /** INT_FAMILY and DATE: adds the long stored in the column. */
  public static class ValueProcessorLong extends ValueProcessor {
    @Override
    protected void processValue(Aggregation myagg, ColumnVector columnVector, int i) {
      LongColumnVector inputColumn = (LongColumnVector) columnVector;
      myagg.bf.addLong(inputColumn.vector[i]);
    }
  }

  /** FLOAT_FAMILY: adds the double stored in the column. */
  public static class ValueProcessorDouble extends ValueProcessor {
    @Override
    protected void processValue(Aggregation myagg, ColumnVector columnVector, int i) {
      DoubleColumnVector inputColumn = (DoubleColumnVector) columnVector;
      myagg.bf.addDouble(inputColumn.vector[i]);
    }
  }

  /** DECIMAL: serializes into a reusable scratch buffer and adds the written tail. */
  public static class ValueProcessorDecimal extends ValueProcessor {
    private byte[] scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES];

    @Override
    protected void processValue(Aggregation myagg, ColumnVector columnVector, int i) {
      DecimalColumnVector inputColumn = (DecimalColumnVector) columnVector;
      // toBytes fills the end of the buffer and returns the index of the first written byte.
      int startIdx = inputColumn.vector[i].toBytes(scratchBuffer);
      myagg.bf.addBytes(scratchBuffer, startIdx, scratchBuffer.length - startIdx);
    }
  }

  /** TIMESTAMP: adds the column's {@code time} value (nanos are not included). */
  public static class ValueProcessorTimestamp extends ValueProcessor {
    @Override
    protected void processValue(Aggregation myagg, ColumnVector columnVector, int i) {
      TimestampColumnVector inputColumn = (TimestampColumnVector) columnVector;
      myagg.bf.addLong(inputColumn.time[i]);
    }
  }
}