/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.udf.generic;
import com.google.common.base.Preconditions;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping.NUMERIC_GROUP;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping.VOID_GROUP;
@Description(name = "width_bucket",
    value = "_FUNC_(expr, min_value, max_value, num_buckets) - Returns an integer between 0 and num_buckets+1 by "
        + "mapping the expr into buckets defined by the range [min_value, max_value]",
    extended = "Returns an integer between 0 and num_buckets+1 by "
        + "mapping expr into the ith equally sized bucket. Buckets are made by dividing [min_value, max_value] into "
        + "equally sized regions. If expr < min_value, return 0, if expr >= max_value return num_buckets+1\n"
        + "Example: expr is an integer column with values 1, 10, 20, 30.\n"
        + " > SELECT _FUNC_(expr, 5, 25, 4) FROM src;\n0\n2\n4\n5")
public class GenericUDFWidthBucket extends GenericUDF {

  /** Raw argument inspectors, kept so evaluate() can read num_buckets directly. */
  private transient ObjectInspector[] objectInspectors;
  /** Standard writable OI for the common comparison type of expr/min_value/max_value. */
  private transient ObjectInspector commonExprMinMaxOI;
  private transient ObjectInspectorConverters.Converter exprConverterOI;
  private transient ObjectInspectorConverters.Converter minValueConverterOI;
  private transient ObjectInspectorConverters.Converter maxValueConverterOI;

  /** Reused output holder to avoid allocating a writable per row. */
  private final IntWritable output = new IntWritable();

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    this.objectInspectors = arguments;

    checkArgsSize(arguments, 4, 4);

    checkArgPrimitive(arguments, 0);
    checkArgPrimitive(arguments, 1);
    checkArgPrimitive(arguments, 2);
    checkArgPrimitive(arguments, 3);

    // Every argument must be numeric; VOID is allowed so literal NULLs type-check.
    PrimitiveObjectInspector.PrimitiveCategory[] inputTypes = new PrimitiveObjectInspector.PrimitiveCategory[4];
    checkArgGroups(arguments, 0, inputTypes, NUMERIC_GROUP, VOID_GROUP);
    checkArgGroups(arguments, 1, inputTypes, NUMERIC_GROUP, VOID_GROUP);
    checkArgGroups(arguments, 2, inputTypes, NUMERIC_GROUP, VOID_GROUP);
    checkArgGroups(arguments, 3, inputTypes, NUMERIC_GROUP, VOID_GROUP);

    // Pick a single common type for expr, min_value and max_value so the three values
    // can be compared without per-row type juggling, and build converters into it.
    TypeInfo exprTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(this.objectInspectors[0]);
    TypeInfo minValueTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(this.objectInspectors[1]);
    TypeInfo maxValueTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(this.objectInspectors[2]);

    TypeInfo commonExprMinMaxTypeInfo = FunctionRegistry.getCommonClassForComparison(exprTypeInfo,
        FunctionRegistry.getCommonClassForComparison(minValueTypeInfo, maxValueTypeInfo));

    this.commonExprMinMaxOI = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(commonExprMinMaxTypeInfo);

    this.exprConverterOI = ObjectInspectorConverters.getConverter(this.objectInspectors[0], this.commonExprMinMaxOI);
    this.minValueConverterOI = ObjectInspectorConverters.getConverter(this.objectInspectors[1], this.commonExprMinMaxOI);
    this.maxValueConverterOI = ObjectInspectorConverters.getConverter(this.objectInspectors[2], this.commonExprMinMaxOI);

    return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
  }

  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    // SQL semantics: any NULL argument yields NULL.
    if (arguments[0].get() == null || arguments[1].get() == null || arguments[2].get() == null
        || arguments[3].get() == null) {
      return null;
    }

    Object exprValue = this.exprConverterOI.convert(arguments[0].get());
    Object minValue = this.minValueConverterOI.convert(arguments[1].get());
    Object maxValue = this.maxValueConverterOI.convert(arguments[2].get());
    int numBuckets = PrimitiveObjectInspectorUtils.getInt(arguments[3].get(),
        (PrimitiveObjectInspector) this.objectInspectors[3]);

    // All integral categories are widened (losslessly) to long so the bucket math is done
    // in 64-bit arithmetic; the old per-type int math overflowed in
    // numBuckets * (expr - min) for large buckets/ranges. FLOAT widens exactly to double.
    switch (((PrimitiveObjectInspector) this.commonExprMinMaxOI).getPrimitiveCategory()) {
    case BYTE:
      return evaluate(((ByteWritable) exprValue).get(), ((ByteWritable) minValue).get(),
          ((ByteWritable) maxValue).get(), numBuckets);
    case SHORT:
      return evaluate(((ShortWritable) exprValue).get(), ((ShortWritable) minValue).get(),
          ((ShortWritable) maxValue).get(), numBuckets);
    case INT:
      return evaluate(((IntWritable) exprValue).get(), ((IntWritable) minValue).get(),
          ((IntWritable) maxValue).get(), numBuckets);
    case LONG:
      return evaluate(((LongWritable) exprValue).get(), ((LongWritable) minValue).get(),
          ((LongWritable) maxValue).get(), numBuckets);
    case FLOAT:
      return evaluate(((FloatWritable) exprValue).get(), ((FloatWritable) minValue).get(),
          ((FloatWritable) maxValue).get(), numBuckets);
    case DOUBLE:
      return evaluate(((DoubleWritable) exprValue).get(), ((DoubleWritable) minValue).get(),
          ((DoubleWritable) maxValue).get(), numBuckets);
    case DECIMAL:
      return evaluate(((HiveDecimalWritable) exprValue).getHiveDecimal(),
          ((HiveDecimalWritable) minValue).getHiveDecimal(), ((HiveDecimalWritable) maxValue).getHiveDecimal(),
          numBuckets);
    default:
      throw new IllegalStateException(
          "Error: width_bucket could not determine a common primitive type for all inputs");
    }
  }

  /**
   * Computes the bucket for integral inputs (BYTE/SHORT/INT/LONG widened to long).
   *
   * Ascending range (max > min): expr below the range maps to 0, at or above max to
   * numBuckets + 1, otherwise to floor(numBuckets * (expr - min) / (max - min)) + 1.
   * Descending range (min > max): the numbering is mirrored.
   *
   * @throws IllegalArgumentException if numBuckets <= 0 or min equals max
   */
  private IntWritable evaluate(long exprValue, long minValue, long maxValue, int numBuckets) {
    Preconditions.checkArgument(numBuckets > 0, "numBuckets in width_bucket function must be above 0");
    Preconditions.checkArgument(maxValue != minValue, "maxValue cannot be equal to minValue in width_bucket function");

    if (maxValue > minValue) {
      if (exprValue < minValue) {
        output.set(0);
      } else if (exprValue >= maxValue) {
        output.set(numBuckets + 1);
      } else {
        // Both operands are non-negative here, so integer division already floors.
        output.set((int) (numBuckets * (exprValue - minValue) / (maxValue - minValue)) + 1);
      }
    } else {
      // Reversed (descending) range.
      if (exprValue > minValue) {
        output.set(0);
      } else if (exprValue <= maxValue) {
        output.set(numBuckets + 1);
      } else {
        output.set((int) (numBuckets * (minValue - exprValue) / (minValue - maxValue)) + 1);
      }
    }
    return output;
  }

  /**
   * Computes the bucket for floating-point inputs (FLOAT widened exactly to double).
   * Same bucket semantics as the integral overload.
   *
   * @throws IllegalArgumentException if numBuckets <= 0 or min equals max
   */
  private IntWritable evaluate(double exprValue, double minValue, double maxValue, int numBuckets) {
    Preconditions.checkArgument(numBuckets > 0, "numBuckets in width_bucket function must be above 0");
    Preconditions.checkArgument(maxValue != minValue, "maxValue cannot be equal to minValue in width_bucket function");

    if (maxValue > minValue) {
      if (exprValue < minValue) {
        output.set(0);
      } else if (exprValue >= maxValue) {
        output.set(numBuckets + 1);
      } else {
        output.set((int) Math.floor((numBuckets * (exprValue - minValue) / (maxValue - minValue)) + 1));
      }
    } else {
      // Reversed (descending) range.
      if (exprValue > minValue) {
        output.set(0);
      } else if (exprValue <= maxValue) {
        output.set(numBuckets + 1);
      } else {
        output.set((int) Math.floor((numBuckets * (minValue - exprValue) / (minValue - maxValue)) + 1));
      }
    }
    return output;
  }

  /**
   * Computes the bucket for DECIMAL inputs using exact HiveDecimal arithmetic.
   * Same bucket semantics as the integral overload; intValue() truncates toward zero,
   * which floors the (positive) quotient computed here.
   *
   * @throws IllegalArgumentException if numBuckets <= 0 or min equals max
   */
  private IntWritable evaluate(HiveDecimal exprValue, HiveDecimal minValue, HiveDecimal maxValue,
      int numBuckets) {
    Preconditions.checkArgument(numBuckets > 0, "numBuckets in width_bucket function must be above 0");
    // compareTo is used instead of equals so values differing only in scale compare equal.
    Preconditions.checkArgument(maxValue.compareTo(minValue) != 0,
        "maxValue cannot be equal to minValue in width_bucket function");

    if (maxValue.compareTo(minValue) > 0) {
      if (exprValue.compareTo(minValue) < 0) {
        output.set(0);
      } else if (exprValue.compareTo(maxValue) >= 0) {
        output.set(numBuckets + 1);
      } else {
        output.set(HiveDecimal.create(numBuckets).multiply(exprValue.subtract(minValue)).divide(
            maxValue.subtract(minValue)).add(HiveDecimal.ONE).intValue());
      }
    } else {
      // Reversed (descending) range.
      if (exprValue.compareTo(minValue) > 0) {
        output.set(0);
      } else if (exprValue.compareTo(maxValue) <= 0) {
        output.set(numBuckets + 1);
      } else {
        output.set(HiveDecimal.create(numBuckets).multiply(minValue.subtract(exprValue)).divide(
            minValue.subtract(maxValue)).add(HiveDecimal.ONE).intValue());
      }
    }
    return output;
  }

  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("width_bucket", children);
  }
}