package brickhouse.udf.collect;
/**
* Copyright 2012 Klout, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**/
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
/**
* UDF for the set difference of two arrays or maps.
*/
@Description(name = "set_diff",
value = "_FUNC_(a,b) - Returns a list of those items in a, but not in b "
)
public class SetDifferenceUDF extends GenericUDF {
private Category category;
private ListObjectInspector list1Inspector;
private ListObjectInspector list2Inspector;
private MapObjectInspector map1Inspector;
private MapObjectInspector map2Inspector;
private PrimitiveObjectInspector prim1Inspector;
private PrimitiveObjectInspector prim2Inspector;
private StandardListObjectInspector stdListInspector;
private StandardMapObjectInspector stdMapInspector;
public List evaluate(List l1, List l2) {
if (l1 == null) {
return null;
}
//// Use a HashSet to avoid linear lookups , for large lists
HashSet negSet = new HashSet();
if (l2 != null) {
for (Object lObj : l2) {
Object inspObj = prim2Inspector.getPrimitiveJavaObject(lObj);
negSet.add(inspObj);
}
}
List newList = (List) stdListInspector.create(0);
for (Object obj : l1) {
Object inspObj = prim1Inspector.getPrimitiveJavaObject(obj);
if (!negSet.contains(inspObj)) {
newList.add(inspObj);
}
}
return newList;
}
public Map evaluate(Map m1, Map m2) {
Map newMap = (Map) stdMapInspector.create();
if (m1 == null) {
return null;
}
HashSet negSet = new HashSet();
if (m2 != null) {
for (Object mObj : m2.keySet()) {
Object inspObj = prim2Inspector.getPrimitiveJavaObject(mObj);
negSet.add(inspObj);
}
}
if (m1.size() > 0)
for (Object k : m1.keySet()) {
Object inspObj = prim1Inspector.getPrimitiveJavaObject(k);
if (!negSet.contains(inspObj)) {
Object valObj = m1.get(k);
Object stdVal = ObjectInspectorUtils.copyToStandardObject(valObj, map1Inspector.getMapValueObjectInspector());
newMap.put(inspObj, stdVal);
}
}
return newMap;
}
@Override
public Object evaluate(DeferredObject[] args) throws HiveException {
if (category == Category.LIST) {
List theList1 = list1Inspector.getList(args[0].get());
List theList2 = list2Inspector.getList(args[1].get());
List retList = evaluate(theList1, theList2);
return retList;
} else if (category == Category.MAP) {
Map theMap1 = map1Inspector.getMap(args[0].get());
Map theMap2 = map2Inspector.getMap(args[1].get());
Map retMap = evaluate(theMap1, theMap2);
return retMap;
} else {
throw new HiveException(" Only maps or lists are supported ");
}
}
@Override
public String getDisplayString(String[] args) {
StringBuilder sb = new StringBuilder("set_diff( ");
for (int i = 0; i < args.length - 1; ++i) {
sb.append(args[i]);
sb.append(",");
}
sb.append(args[args.length - 1]);
sb.append(")");
return sb.toString();
}
@Override
public ObjectInspector initialize(ObjectInspector[] args)
throws UDFArgumentException {
if (args.length != 2) {
throw new UDFArgumentException("Usage: set_diff takes 2 maps or lists, and returns the difference");
}
ObjectInspector first = args[0];
ObjectInspector second = args[1];
if (first.getCategory() == Category.LIST
&& second.getCategory() == Category.LIST) {
category = first.getCategory();
list1Inspector = (ListObjectInspector) first;
list2Inspector = (ListObjectInspector) second;
if (list1Inspector.getListElementObjectInspector().getCategory() != Category.PRIMITIVE
|| list2Inspector.getListElementObjectInspector().getCategory() != Category.PRIMITIVE) {
throw new UDFArgumentException(" set_diff only takes maps or lists of primitives.");
}
prim1Inspector = (PrimitiveObjectInspector) list1Inspector.getListElementObjectInspector();
prim2Inspector = (PrimitiveObjectInspector) list2Inspector.getListElementObjectInspector();
if (prim1Inspector.getPrimitiveCategory() != prim2Inspector.getPrimitiveCategory()) {
throw new UDFArgumentException(" set_diff takes only lists of the same primitive type.");
}
stdListInspector = ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorUtils.getStandardObjectInspector(prim1Inspector, ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA));
return stdListInspector;
} else if (first.getCategory() == Category.MAP
&& second.getCategory() == Category.MAP) {
category = first.getCategory();
map1Inspector = (MapObjectInspector) first;
map2Inspector = (MapObjectInspector) second;
if (map1Inspector.getMapKeyObjectInspector().getCategory() != Category.PRIMITIVE
|| map2Inspector.getMapKeyObjectInspector().getCategory() != Category.PRIMITIVE) {
throw new UDFArgumentException(" set_diff only takes maps or lists of primitives.");
}
prim1Inspector = (PrimitiveObjectInspector) map1Inspector.getMapKeyObjectInspector();
prim2Inspector = (PrimitiveObjectInspector) map2Inspector.getMapKeyObjectInspector();
if (prim1Inspector.getPrimitiveCategory() != prim2Inspector.getPrimitiveCategory()) {
throw new UDFArgumentException(" set_diff takes only maps of the same primitive type.");
}
stdMapInspector = ObjectInspectorFactory.getStandardMapObjectInspector(
ObjectInspectorUtils.getStandardObjectInspector(prim1Inspector, ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA),
ObjectInspectorUtils.getStandardObjectInspector(map1Inspector.getMapValueObjectInspector())
);
return stdMapInspector;
} else {
throw new UDFArgumentException(" set_diff only takes maps or lists.");
}
}
}