package com.facebook.hive.udf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Hive generic UDF that returns a random sample of at most N elements taken
 * from an array.
 *
 * <p>When N is greater than or equal to the array length, every element is
 * returned in its original order. Otherwise a copy of the array is shuffled
 * and its first N elements are returned. A NULL count or a NULL array yields
 * NULL.
 *
 * <p>The function is random, so it is flagged as non-deterministic to keep
 * the planner from constant-folding or caching calls whose arguments are
 * constants.
 */
@Description(name = "udfsample",
             value = "_FUNC_(N, A) - Randomly samples (at most) N elements from array A.")
@UDFType(deterministic = false)
public class UDFSample extends GenericUDF {
  // Both fields are (re)built by initialize(); they are not configuration and
  // therefore need not travel with a serialized copy of the UDF.
  private transient ObjectInspectorConverters.Converter intConverter;
  private transient ListObjectInspector arrayOI;

  /**
   * Validates the argument types and prepares the converter and inspector
   * used by {@link #evaluate}.
   *
   * @param arguments inspectors for (N, A): a primitive count and a list
   * @return a standard list inspector over the input's element inspector,
   *         matching the plain {@code java.util.List} that evaluate returns
   * @throws UDFArgumentException if the argument count or categories are wrong
   */
  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments)
      throws UDFArgumentException {
    if (arguments.length != 2) {
      throw new UDFArgumentLengthException("SAMPLE expects two arguments.");
    }

    if (!arguments[0].getCategory().equals(Category.PRIMITIVE)) {
      throw new UDFArgumentTypeException(0, "SAMPLE expects an INTEGER as its first argument");
    }
    intConverter = ObjectInspectorConverters.getConverter(
        arguments[0], PrimitiveObjectInspectorFactory.writableIntObjectInspector);

    if (!arguments[1].getCategory().equals(Category.LIST)) {
      throw new UDFArgumentTypeException(1, "SAMPLE expects an ARRAY as its second argument");
    }
    arrayOI = (ListObjectInspector) arguments[1];

    // evaluate() hands back a java.util.List, so describe the result with a
    // standard list inspector instead of echoing the caller's inspector,
    // which may expect its own container representation.
    return ObjectInspectorFactory.getStandardListObjectInspector(
        arrayOI.getListElementObjectInspector());
  }

  /**
   * Draws the sample for one row.
   *
   * @param arguments deferred values for (N, A)
   * @return a new list holding at most N elements of A, or {@code null} when
   *         either N or A is NULL
   * @throws HiveException if N is negative or an argument cannot be read
   */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    IntWritable requested = (IntWritable) intConverter.convert(arguments[0].get());
    if (requested == null) {
      return null;
    }

    int sampleSize = requested.get();
    if (sampleSize < 0) {
      throw new UDFArgumentException("SAMPLE requires a nonnegative number of elements to sample.");
    }

    // Read the deferred array argument exactly once.
    List<?> elements = arrayOI.getList(arguments[1].get());
    if (elements == null) {
      return null;
    }

    // Work on a private copy so the input is never reordered.
    List<Object> pool = new ArrayList<Object>(elements);
    if (sampleSize >= pool.size()) {
      // Nothing to drop: return every element in its original order.
      return pool;
    }

    Collections.shuffle(pool);
    // Copy the prefix so the result does not keep the whole shuffled pool
    // alive through a subList view.
    return new ArrayList<Object>(pool.subList(0, sampleSize));
  }

  /**
   * Renders the call for EXPLAIN output.
   *
   * @param children display strings of the two arguments
   * @return the call rendered as {@code fb_sample(N, A)}
   */
  @Override
  public String getDisplayString(String[] children) {
    assert (children.length == 2);
    return "fb_sample(" + children[0] + ", " + children[1] + ")";
  }
}