/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ColStatistics;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
import org.apache.hadoop.hive.ql.plan.Statistics;
import org.apache.hadoop.hive.ql.plan.Statistics.State;
import org.apache.hadoop.hive.serde2.io.DateWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.*;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text;
import org.apache.hive.common.util.BloomFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.sql.Timestamp;
import java.util.List;

/**
 * Generic UDAF to generate a Bloom filter.
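 * <p>
 * Registered as {@code bloom_filter} and normally planted by the semijoin
 * reduction optimizer (paired with {@code in_bloom_filter} on the probe
 * side) rather than written by hand. A rough sketch of how the planner
 * wires up the evaluator (the config property names are assumed, shown
 * only for illustration):
 * <pre>
 *   GenericUDAFBloomFilterEvaluator eval = (GenericUDAFBloomFilterEvaluator)
 *       new GenericUDAFBloomFilter().getEvaluator(
 *           new TypeInfo[] { TypeInfoFactory.longTypeInfo });
 *   eval.setSourceOperator(selectOp);  // supplies row-count / NDV statistics
 *   eval.setMaxEntries(10000000L);     // e.g. hive.tez.max.bloom.filter.entries
 *   eval.setMinEntries(1000000L);      // e.g. hive.tez.min.bloom.filter.entries
 *   eval.setFactor(1.0f);              // e.g. hive.tez.bloom.filter.factor
 * </pre>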
*/
public class GenericUDAFBloomFilter implements GenericUDAFResolver2 {
private static final Logger LOG = LoggerFactory.getLogger(GenericUDAFBloomFilter.class);
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
return new GenericUDAFBloomFilterEvaluator();
}
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
return new GenericUDAFBloomFilterEvaluator();
}
/**
* GenericUDAFBloomFilterEvaluator - Evaluator class for BloomFilter
*/
public static class GenericUDAFBloomFilterEvaluator extends GenericUDAFEvaluator {
// Source operator to get the number of entries
private SelectOperator sourceOperator;
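    // Sizing knobs: an optional hint from the query, configured min/max
    // bounds, and a scale factor applied to the statistics-based estimate.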
private long hintEntries = -1;
private long maxEntries = 0;
private long minEntries = 0;
private float factor = 1;
// ObjectInspector for input data.
private PrimitiveObjectInspector inputOI;
    // Reusable buffer holding the serialized Bloom filter result.
private ByteArrayOutputStream result = new ByteArrayOutputStream();
private transient byte[] scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES];
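
    // PARTIAL1 and COMPLETE see the original input column, so they need an
    // input ObjectInspector; PARTIAL2 and FINAL receive the serialized
    // filter produced by terminatePartial() and have nothing to set up.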
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
super.init(m, parameters);
// Initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
inputOI = (PrimitiveObjectInspector) parameters[0];
} else {
// Do nothing for other modes
}
      // Output is the same in both partial and full aggregation modes:
      // the serialized BloomFilter wrapped in a BytesWritable.
return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector;
}
/**
* Class for storing the BloomFilter
*/
@AggregationType(estimable = true)
static class BloomFilterBuf extends AbstractAggregationBuffer {
BloomFilter bloomFilter;
public BloomFilterBuf(long expectedEntries, long maxEntries) {
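        // Cap the allocation: beyond the configured maximum, fall back to a
        // trivially small filter rather than building an oversized one.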
if (expectedEntries > maxEntries) {
bloomFilter = new BloomFilter(1);
} else {
bloomFilter = new BloomFilter(expectedEntries);
}
}
@Override
public int estimate() {
return (int) bloomFilter.sizeInBytes();
}
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
((BloomFilterBuf)agg).bloomFilter.reset();
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
long expectedEntries = getExpectedEntries();
if (expectedEntries < 0) {
throw new IllegalStateException("BloomFilter expectedEntries not initialized");
}
BloomFilterBuf buf = new BloomFilterBuf(expectedEntries, maxEntries);
reset(buf);
return buf;
}
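
    // Hash one input value into the filter: integral types are widened to
    // long, floating-point types to double, and decimal, string-like and
    // binary values are added as raw bytes.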
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
if (parameters == null || parameters[0] == null) {
        // parameters[0] is null when the input expression evaluates to NULL
        // for this row (possible due to filtering, outer joins etc.); skip it.
return;
}
BloomFilter bf = ((BloomFilterBuf)agg).bloomFilter;
// Add the expression into the BloomFilter
switch (inputOI.getPrimitiveCategory()) {
case BOOLEAN:
boolean vBoolean = ((BooleanObjectInspector)inputOI).get(parameters[0]);
bf.addLong(vBoolean ? 1 : 0);
break;
case BYTE:
byte vByte = ((ByteObjectInspector)inputOI).get(parameters[0]);
bf.addLong(vByte);
break;
case SHORT:
short vShort = ((ShortObjectInspector)inputOI).get(parameters[0]);
bf.addLong(vShort);
break;
case INT:
int vInt = ((IntObjectInspector)inputOI).get(parameters[0]);
bf.addLong(vInt);
break;
case LONG:
long vLong = ((LongObjectInspector)inputOI).get(parameters[0]);
bf.addLong(vLong);
break;
case FLOAT:
float vFloat = ((FloatObjectInspector)inputOI).get(parameters[0]);
bf.addDouble(vFloat);
break;
case DOUBLE:
double vDouble = ((DoubleObjectInspector)inputOI).get(parameters[0]);
bf.addDouble(vDouble);
break;
case DECIMAL:
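        // Serialize the decimal into the reusable scratch buffer; toBytes()
        // writes into the tail of the buffer and returns the offset where
        // the valid bytes begin.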
HiveDecimalWritable vDecimal = ((HiveDecimalObjectInspector)inputOI).
getPrimitiveWritableObject(parameters[0]);
int startIdx = vDecimal.toBytes(scratchBuffer);
bf.addBytes(scratchBuffer, startIdx, scratchBuffer.length - startIdx);
break;
case DATE:
DateWritable vDate = ((DateObjectInspector)inputOI).
getPrimitiveWritableObject(parameters[0]);
bf.addLong(vDate.getDays());
break;
case TIMESTAMP:
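        // Timestamps are hashed on epoch milliseconds; sub-millisecond
        // precision is dropped.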
Timestamp vTimeStamp = ((TimestampObjectInspector)inputOI).
getPrimitiveJavaObject(parameters[0]);
bf.addLong(vTimeStamp.getTime());
break;
case CHAR:
Text vChar = ((HiveCharObjectInspector)inputOI).
getPrimitiveWritableObject(parameters[0]).getStrippedValue();
bf.addBytes(vChar.getBytes(), 0, vChar.getLength());
break;
case VARCHAR:
Text vVarChar = ((HiveVarcharObjectInspector)inputOI).
getPrimitiveWritableObject(parameters[0]).getTextValue();
bf.addBytes(vVarChar.getBytes(), 0, vVarChar.getLength());
break;
case STRING:
Text vString = ((StringObjectInspector)inputOI).
getPrimitiveWritableObject(parameters[0]);
bf.addBytes(vString.getBytes(), 0, vString.getLength());
break;
case BINARY:
BytesWritable vBytes = ((BinaryObjectInspector)inputOI).
getPrimitiveWritableObject(parameters[0]);
bf.addBytes(vBytes.getBytes(), 0, vBytes.getLength());
break;
default:
throw new UDFArgumentTypeException(0,
"Bad primitive category " + inputOI.getPrimitiveCategory());
}
}
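
    /**
     * Merges a serialized partial filter (as produced by terminatePartial())
     * into this buffer. BloomFilter.merge ORs the underlying bit sets, which
     * requires identically sized filters; that holds here because every task
     * sizes its filter from the same expectedEntries.
     */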
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial == null) {
return;
}
BytesWritable bytes = (BytesWritable) partial;
      // Deserialize the partial Bloom filter; bound the stream to the valid
      // region since BytesWritable's backing array may be larger than its
      // contents.
      ByteArrayInputStream in =
          new ByteArrayInputStream(bytes.getBytes(), 0, bytes.getLength());
try {
BloomFilter bf = BloomFilter.deserialize(in);
((BloomFilterBuf)agg).bloomFilter.merge(bf);
} catch (IOException e) {
throw new HiveException(e);
}
}
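
    // terminate() and terminatePartial() share one representation: the
    // filter serialized into a reusable buffer and copied into a
    // BytesWritable.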
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
result.reset();
try {
BloomFilter.serialize(result, ((BloomFilterBuf)agg).bloomFilter);
} catch (IOException e) {
throw new HiveException(e);
}
return new BytesWritable(result.toByteArray());
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
return terminate(agg);
}
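
    /**
     * Estimates the number of entries used to size the filter: an explicit
     * hint wins; otherwise the source operator's column NDV is used when
     * column statistics are available, falling back to its row count. The
     * estimate is then scaled by factor and floored at minEntries.
     */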
public long getExpectedEntries() {
      // If a hint was provided, use that size.
      if (hintEntries > 0) {
        return hintEntries;
      }
long expectedEntries = -1;
if (sourceOperator != null && sourceOperator.getStatistics() != null) {
Statistics stats = sourceOperator.getStatistics();
expectedEntries = stats.getNumRows();
// Use NumDistinctValues if possible
switch (stats.getColumnStatsState()) {
case COMPLETE:
case PARTIAL:
        // There should be only one column in sourceOperator's column list.
List<ColStatistics> colStats = stats.getColumnStats();
ExprNodeColumnDesc colExpr = ExprNodeDescUtils.getColumnExpr(
sourceOperator.getConf().getColList().get(0));
if (colExpr != null
&& stats.getColumnStatisticsFromColName(colExpr.getColumn()) != null) {
long ndv = stats.getColumnStatisticsFromColName(colExpr.getColumn()).getCountDistint();
if (ndv > 0) {
expectedEntries = ndv;
}
}
break;
default:
break;
}
}
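      // With neither a hint nor statistics, expectedEntries stays -1 here
      // and the ternary below degrades it to minEntries.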
// Update expectedEntries based on factor and minEntries configurations
expectedEntries = (long) (expectedEntries * factor);
expectedEntries = expectedEntries > minEntries ? expectedEntries : minEntries;
return expectedEntries;
}
public Operator<?> getSourceOperator() {
return sourceOperator;
}
public void setSourceOperator(SelectOperator sourceOperator) {
this.sourceOperator = sourceOperator;
}
public void setHintEntries(long hintEntries) {
this.hintEntries = hintEntries;
}
public boolean hasHintEntries() {
return hintEntries != -1;
}
public void setMaxEntries(long maxEntries) {
this.maxEntries = maxEntries;
}
public void setMinEntries(long minEntries) {
this.minEntries = minEntries;
}
public long getMinEntries() {
return minEntries;
}
public void setFactor(float factor) {
this.factor = factor;
}
public float getFactor() {
return factor;
}
@Override
public String getExprString() {
return "expectedEntries=" + getExpectedEntries();
}
}
}