package brickhouse.udf.timeseries;
/**
* Copyright 2012 Klout, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**/
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.log4j.Logger;
import java.util.HashMap;
import java.util.Map;
/**
* Multiply a vector of numbers times another vector
*/
@Description(
name = "vector_cross_product",
value = " Multiply a vector times another vector"
)
public class VectorCrossProductUDF extends GenericUDF {
private static final Logger LOG = Logger.getLogger(VectorCrossProductUDF.class);
private ListObjectInspector list1Inspector;
private ListObjectInspector list2Inspector;
private MapObjectInspector map1Inspector;
private MapObjectInspector map2Inspector;
private PrimitiveObjectInspector key1Inspector;
private PrimitiveObjectInspector key2Inspector;
private PrimitiveObjectInspector value1Inspector;
private PrimitiveObjectInspector value2Inspector;
private StandardListObjectInspector retListInspector;
private StandardMapObjectInspector retMapInspector;
public Object evaluateList(Object list1Obj, Object list2Obj) {
int len1 = list1Inspector.getListLength(list1Obj);
int len2 = list2Inspector.getListLength(list2Obj);
if (len1 != len2) {
LOG.warn("vector lengths do not match " + list1Obj + " :: " + list2Obj);
return null;
}
Object retList = retListInspector.create(0);
for (int i = 0; i < len1; ++i) {
Object list1Val = this.list1Inspector.getListElement(list1Obj, i);
double list1Dbl = NumericUtil.getNumericValue(value1Inspector, list1Val);
Object list2Val = this.list2Inspector.getListElement(list2Obj, i);
double list2Dbl = NumericUtil.getNumericValue(value2Inspector, list2Val);
double newVal = list1Dbl * list2Dbl;
retListInspector.set(retList, i, NumericUtil.castToPrimitiveNumeric(newVal,
((PrimitiveObjectInspector) retListInspector.getListElementObjectInspector()).getPrimitiveCategory()));
}
return retList;
}
public Object evaluateMap(Object uninspMapObj1, Object uninspMapObj2) {
/// A little tricky, because keys won't match if the ObjectInspectors aren't the
/// same .. If the ObjectInspectors are the same class, assume they can be compared
Object retMap = retMapInspector.create();
Map map1 = map1Inspector.getMap(uninspMapObj1);
Map map2 = map2Inspector.getMap(uninspMapObj2);
boolean simpleLookup = map1Inspector.getMapKeyObjectInspector().getClass().equals(
map2Inspector.getMapKeyObjectInspector());
Map stdKeyMap = new HashMap();
if (!simpleLookup) {
for (Object mapKey2 : map2.keySet()) {
Object stdKey2 = ObjectInspectorUtils.copyToStandardJavaObject(mapKey2,
map2Inspector.getMapKeyObjectInspector());
stdKeyMap.put(stdKey2, mapKey2);
}
}
for (Object mapKey1 : map1.keySet()) {
Object mapVal1Obj = map1.get(mapKey1);
double mapVal1Dbl = NumericUtil.getNumericValue(value1Inspector, mapVal1Obj);
Object stdKey1 = ObjectInspectorUtils.copyToStandardJavaObject(mapKey1,
map1Inspector.getMapKeyObjectInspector());
Object mapVal2Obj = null;
if (simpleLookup) {
mapVal2Obj = map2.get(mapKey1);
} else {
/// Need to do lookup in stdKeyMap
mapVal2Obj = map2.get(stdKeyMap.get(stdKey1));
}
if (mapVal2Obj != null) {
double mapVal2Dbl = NumericUtil.getNumericValue(value2Inspector, mapVal2Obj);
double newVal = mapVal1Dbl * mapVal2Dbl;
Object stdVal = NumericUtil.castToPrimitiveNumeric(newVal,
((PrimitiveObjectInspector) retMapInspector.getMapValueObjectInspector()).getPrimitiveCategory());
retMapInspector.put(retMap, stdKey1, stdVal);
}
}
return retMap;
}
@Override
public Object evaluate(DeferredObject[] arg0) throws HiveException {
if (list1Inspector != null) {
return evaluateList(arg0[0].get(), arg0[1].get());
} else {
return evaluateMap(arg0[0].get(), arg0[1].get());
}
}
@Override
public String getDisplayString(String[] arg0) {
return "vector_cross_product";
}
private void usage(String message) throws UDFArgumentException {
LOG.error("vector_cross_product: Multiply a vector times another vector : " + message);
throw new UDFArgumentException("vector_scalar_mult: Multiply a vector times another vector : " + message);
}
@Override
public ObjectInspector initialize(ObjectInspector[] arg0)
throws UDFArgumentException {
if (arg0.length != 2)
usage("Must have two arguments.");
if (arg0[0].getCategory() == Category.MAP) {
if (arg0[1].getCategory() != Category.MAP)
usage("Vectors need to be both maps");
this.map1Inspector = (MapObjectInspector) arg0[0];
this.map2Inspector = (MapObjectInspector) arg0[1];
if (map1Inspector.getMapKeyObjectInspector().getCategory() != Category.PRIMITIVE)
usage("First Vector map key must be a primitive");
this.key1Inspector = (PrimitiveObjectInspector) map1Inspector.getMapKeyObjectInspector();
if (map2Inspector.getMapKeyObjectInspector().getCategory() != Category.PRIMITIVE)
usage("Second Vector map key must be a primitive");
this.key2Inspector = (PrimitiveObjectInspector) map2Inspector.getMapKeyObjectInspector();
if (key2Inspector.getPrimitiveCategory() != key1Inspector.getPrimitiveCategory())
usage(" Map key types must match");
if (map1Inspector.getMapValueObjectInspector().getCategory() != Category.PRIMITIVE)
usage("First Vector map value must be a primitive");
this.value1Inspector = (PrimitiveObjectInspector) map1Inspector.getMapValueObjectInspector();
if (map2Inspector.getMapValueObjectInspector().getCategory() != Category.PRIMITIVE)
usage("Second Vector map value must be a primitive");
this.value2Inspector = (PrimitiveObjectInspector) map2Inspector.getMapValueObjectInspector();
} else if (arg0[0].getCategory() == Category.LIST) {
if (arg0[1].getCategory() != Category.LIST)
usage("Vectors need to be both arrays");
this.list1Inspector = (ListObjectInspector) arg0[0];
this.list2Inspector = (ListObjectInspector) arg0[1];
if (list1Inspector.getListElementObjectInspector().getCategory() != Category.PRIMITIVE)
usage("First Vector array value must be a primitive");
this.value1Inspector = (PrimitiveObjectInspector) list1Inspector.getListElementObjectInspector();
if (list2Inspector.getListElementObjectInspector().getCategory() != Category.PRIMITIVE)
usage("Second Vector array value must be a primitive");
this.value2Inspector = (PrimitiveObjectInspector) list2Inspector.getListElementObjectInspector();
} else {
usage("Arguments must be arrays or maps");
}
if (list1Inspector != null) {
retListInspector = ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorUtils.getStandardObjectInspector(value1Inspector,
ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA));
return retListInspector;
} else {
retMapInspector = ObjectInspectorFactory.getStandardMapObjectInspector(
ObjectInspectorUtils.getStandardObjectInspector(map1Inspector.getMapKeyObjectInspector(),
ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA),
ObjectInspectorUtils.getStandardObjectInspector(value1Inspector,
ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA));
return retMapInspector;
}
}
}