/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.udf.generic;
import java.util.HashSet;
import java.util.Set;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ReturnObjectInspectorResolver;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BooleanWritable;
import com.esotericsoftware.minlog.Log;
/**
* GenericUDFIn
*
* Example usage:
* SELECT key FROM src WHERE key IN ("238", "1");
*
* From MySQL page on IN(): To comply with the SQL standard, IN returns NULL
* not only if the expression on the left hand side is NULL, but also if no
* match is found in the list and one of the expressions in the list is NULL.
*
* Also noteworthy: type conversion behavior is different from MySQL. With
* expr IN expr1, expr2... in MySQL, exprN will each be converted into the same
* type as expr. In the Hive implementation, all expr(N) will be converted into
* a common type for conversion consistency with other UDFs, and to prevent
* conversions from a big type to a small type (e.g. int to tinyint)
*/
@Description(name = "in",
    value = "test _FUNC_(val1, val2...) - returns true if test equals any valN ")
public class GenericUDFIn extends GenericUDF {

  /** Inspectors for all arguments; index 0 is the tested value, 1..n the IN list. */
  private transient ObjectInspector[] argumentOIs;

  // This set is a copy of the argument objects - transient to avoid serializing it.
  // Built lazily on the first evaluate() call that takes the constant path.
  private transient Set<Object> constantInSet;

  // True while every IN(...) list element seen so far was a ConstantObjectInspector.
  // It is only ever cleared (never set back to true), so a stale "false" merely
  // selects the slower element-by-element comparison path.
  private boolean isInSetConstant = true;

  /** Reused result holder; evaluate() returns this instance for true/false. */
  private final BooleanWritable bw = new BooleanWritable();

  /** Resolves the common type of all arguments and converts values to it. */
  private transient ReturnObjectInspector​Resolver conversionHelper;

  /** Inspector of the common type that all arguments are converted to. */
  private transient ObjectInspector compareOI;

  /**
   * Validates the argument count, resolves a common comparison type for all
   * arguments, and records whether the IN list consists only of constants.
   *
   * @param arguments inspectors for the tested value followed by the IN list
   * @return the writable boolean object inspector
   * @throws UDFArgumentLengthException if fewer than two arguments are given
   * @throws UDFArgumentException if the arguments have no common type
   */
  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments)
      throws UDFArgumentException {
    if (arguments.length < 2) {
      throw new UDFArgumentLengthException(
          "The function IN requires at least two arguments, got "
          + arguments.length);
    }

    argumentOIs = arguments;
    // Drop any set cached by a previous initialize()/evaluate() cycle so that it is
    // rebuilt from the arguments that match the inspectors supplied here.
    constantInSet = null;

    // We want to use the ReturnObjectInspectorResolver because otherwise
    // ObjectInspectorUtils.compare() will return != for two objects that have
    // different object inspectors, e.g. 238 and "238". The ROIR will help convert
    // both values to a common type so that they can be compared reasonably.
    conversionHelper = new GenericUDFUtils.ReturnObjectInspectorResolver(true);

    for (ObjectInspector oi : arguments) {
      if (!conversionHelper.update(oi)) {
        throw new UDFArgumentException(describeTypeMismatch(arguments));
      }
    }
    compareOI = conversionHelper.get();

    checkIfInSetConstant();

    return PrimitiveObjectInspectorFactory.writableBooleanObjectInspector;
  }

  /** Builds the error message listing the type name of every argument. */
  private static String describeTypeMismatch(ObjectInspector[] arguments) {
    StringBuilder sb = new StringBuilder();
    sb.append("The arguments for IN should be the same type! Types are: {");
    sb.append(arguments[0].getTypeName());
    sb.append(" IN (");
    for (int i = 1; i < arguments.length; i++) {
      if (i != 1) {
        sb.append(", ");
      }
      sb.append(arguments[i].getTypeName());
    }
    sb.append(")}");
    return sb.toString();
  }

  /** Clears {@link #isInSetConstant} if any IN list element is not a constant. */
  private void checkIfInSetConstant() {
    for (int i = 1; i < argumentOIs.length; ++i) {
      if (!(argumentOIs[i] instanceof ConstantObjectInspector)) {
        isInSetConstant = false;
        return;
      }
    }
  }

  /**
   * Fills {@link #constantInSet} from the IN list.
   *
   * We start at index 1, since at 0 is the variable from the table column
   * (and those from IN(...) follow it).
   *
   * @param arguments the deferred arguments passed to evaluate()
   * @throws HiveException if an argument cannot be obtained
   */
  private void prepareInSet(DeferredObject[] arguments) throws HiveException {
    constantInSet = new HashSet<Object>();
    if (compareOI.getCategory().equals(ObjectInspector.Category.PRIMITIVE)) {
      PrimitiveObjectInspector primitiveOI = (PrimitiveObjectInspector) compareOI;
      for (int i = 1; i < arguments.length; ++i) {
        Object converted =
            conversionHelper.convertIfNecessary(arguments[i].get(), argumentOIs[i]);
        constantInSet.add(primitiveOI.getPrimitiveJavaObject(converted));
      }
    } else {
      // NOTE(review): non-primitive constants are stored as their writable constant
      // value, while evaluate() probes the set with the getList/getMap/
      // getStructFieldsDataAsList form of the converted left operand - confirm that
      // these two representations are equals()/hashCode()-compatible.
      for (int i = 1; i < arguments.length; ++i) {
        constantInSet.add(
            ((ConstantObjectInspector) argumentOIs[i]).getWritableConstantValue());
      }
    }
  }

  /**
   * Converts a value to the common type and extracts the representation that is
   * looked up in {@link #constantInSet} for the category of {@link #compareOI}.
   *
   * @param value the raw value to convert
   * @param valueOI the inspector describing {@code value}
   * @return the object to probe the constant set with
   * @throws RuntimeException if the common type's category is not PRIMITIVE, LIST,
   *         MAP or STRUCT
   */
  private Object toSetKey(Object value, ObjectInspector valueOI) {
    switch (compareOI.getCategory()) {
    case PRIMITIVE:
      return ((PrimitiveObjectInspector) compareOI).getPrimitiveJavaObject(
          conversionHelper.convertIfNecessary(value, valueOI));
    case LIST:
      return ((ListObjectInspector) compareOI).getList(
          conversionHelper.convertIfNecessary(value, valueOI));
    case MAP:
      return ((MapObjectInspector) compareOI).getMap(
          conversionHelper.convertIfNecessary(value, valueOI));
    case STRUCT:
      return ((StructObjectInspector) compareOI).getStructFieldsDataAsList(
          conversionHelper.convertIfNecessary(value, valueOI));
    default:
      throw new RuntimeException("Compare of unsupported constant type: "
          + compareOI.getCategory());
    }
  }

  /**
   * Evaluates {@code arguments[0] IN (arguments[1], ..., arguments[n])}.
   *
   * @param arguments the tested value followed by the IN list
   * @return {@code null} if the tested value is null, or if nothing matched and
   *         the IN list contains a null; otherwise the shared
   *         {@link BooleanWritable} set to true on a match and false otherwise
   * @throws HiveException if an argument cannot be obtained
   */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    bw.set(false);

    Object tested = arguments[0].get();
    if (tested == null) {
      return null;
    }

    if (isInSetConstant) {
      if (constantInSet == null) {
        prepareInSet(arguments);
      }
      if (constantInSet.contains(toSetKey(tested, argumentOIs[0]))) {
        bw.set(true);
        return bw;
      }
      // Nothing matched; a null among the constants makes the result null.
      if (constantInSet.contains(null)) {
        return null;
      }
      return bw;
    }

    boolean sawNull = false;
    for (int i = 1; i < arguments.length; i++) {
      // NOTE(review): converting the left operand looks loop-invariant, but it is
      // kept inside the loop because the resolver's object-reuse behaviour is not
      // visible from this file - confirm before hoisting.
      Object left = conversionHelper.convertIfNecessary(tested, argumentOIs[0]);
      Object candidate = arguments[i].get();
      // The third argument (false) is passed exactly as before; presumably it
      // disables converter/object reuse for the list element - TODO confirm.
      Object right =
          conversionHelper.convertIfNecessary(candidate, argumentOIs[i], false);
      if (ObjectInspectorUtils.compare(left, compareOI, right, compareOI) == 0) {
        bw.set(true);
        return bw;
      }
      if (candidate == null) {
        sawNull = true;
      }
    }

    // Nothing matched. IN yields null when the list contains a null and no
    // element matched.
    if (sawNull) {
      return null;
    }
    return bw;
  }

  /**
   * Renders the expression as {@code (child0) IN (child1, child2, ...)}.
   *
   * @param children display strings of the arguments, tested value first
   * @return the display string
   */
  @Override
  public String getDisplayString(String[] children) {
    assert (children.length >= 2);

    StringBuilder sb = new StringBuilder();
    sb.append("(");
    sb.append(children[0]);
    sb.append(") ");
    sb.append("IN (");
    for (int i = 1; i < children.length; i++) {
      sb.append(children[i]);
      if (i + 1 != children.length) {
        sb.append(", ");
      }
    }
    sb.append(")");
    return sb.toString();
  }
}