/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.tools.mapred;

import hivemall.ftvec.ExtractFeatureUDF;
import hivemall.utils.collections.OpenHashMap;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.Text;
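/**
 * Looks up keys against a key-value file shipped to each task node, typically via Hive's
 * distributed cache ({@code ADD FILE}), returning the matched value or the given default on a
 * miss.
 *
 * A hypothetical usage sketch (the file path, table, and column names below are illustrative
 * only, not part of this class):
 *
 * <pre>
 * ADD FILE /tmp/item_names.tsv;
 * SELECT distcache_gets('/tmp/item_names.tsv', item_id, 'N/A') FROM sales;
 * </pre>
 */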
@Description(name = "distcache_gets",
        value = "_FUNC_(filepath, key, default_value [, parseKey]) - Returns map<key_type, value_type>|value_type")
@UDFType(deterministic = false, stateful = false)
public final class DistributedCacheLookupUDF extends GenericUDF {

    private boolean multipleKeyLookup;
    private boolean multipleDefaultValues;
    private boolean parseKey;

    private Object defaultValue;
    private PrimitiveObjectInspector keyInputOI;
    private PrimitiveObjectInspector valueInputOI;
    private ListObjectInspector keysInputOI;
    private ListObjectInspector valuesInputOI;

    private OpenHashMap<Object, Object> cache;

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 3 && argOIs.length != 4) {
            throw new UDFArgumentException(
                "Invalid number of arguments for distcache_gets(FILEPATH, KEYS, DEFAULT_VAL, PARSE_KEY): "
                        + argOIs.length + getUsage());
        }
        if (!ObjectInspectorUtils.isConstantObjectInspector(argOIs[2])) {
            throw new UDFArgumentException("Third argument DEFAULT_VALUE must be a constant value: "
                    + TypeInfoUtils.getTypeInfoFromObjectInspector(argOIs[2]));
        }
        if (argOIs.length == 4) {
            this.parseKey = HiveUtils.getConstBoolean(argOIs[3]);
        } else {
            this.parseKey = false;
        }

        String filepath = HiveUtils.getConstString(argOIs[0]);

        // The default value is either a single constant or a list of constants (one per lookup key)
        ObjectInspector argOI2 = argOIs[2];
        this.multipleDefaultValues = (argOI2.getCategory() == Category.LIST);
        if (multipleDefaultValues) {
            this.valuesInputOI = (ListObjectInspector) argOI2;
            ObjectInspector valuesElemOI = valuesInputOI.getListElementObjectInspector();
            this.valueInputOI = HiveUtils.asPrimitiveObjectInspector(valuesElemOI);
        } else {
            this.defaultValue = HiveUtils.getConstValue(argOI2);
            this.valueInputOI = HiveUtils.asPrimitiveObjectInspector(argOI2);
        }
        ObjectInspector valueOutputOI = ObjectInspectorUtils.getStandardObjectInspector(
            valueInputOI, ObjectInspectorCopyOption.WRITABLE);

        // A primitive key yields a single value; a list of keys yields a map of key to value
        final ObjectInspector outputOI;
        switch (argOIs[1].getCategory()) {
            case PRIMITIVE:
                this.multipleKeyLookup = false;
                this.keyInputOI = (PrimitiveObjectInspector) argOIs[1];
                outputOI = valueOutputOI;
                break;
            case LIST:
                this.multipleKeyLookup = true;
                this.keysInputOI = (ListObjectInspector) argOIs[1];
                ObjectInspector keysElemOI = keysInputOI.getListElementObjectInspector();
                this.keyInputOI = HiveUtils.asPrimitiveObjectInspector(keysElemOI);
                outputOI = ObjectInspectorFactory.getStandardMapObjectInspector(keyInputOI,
                    valueOutputOI);
                break;
            default:
                throw new UDFArgumentException("Unexpected key type: " + argOIs[1].getTypeName());
        }

        if (parseKey && !HiveUtils.isStringOI(keyInputOI)) {
            throw new UDFArgumentException(
                "parseKey=true is only available for string typed key(s)");
        }

        final OpenHashMap<Object, Object> map = new OpenHashMap<Object, Object>(8192);
        try {
            loadValues(map, new File(filepath), keyInputOI, valueInputOI);
            this.cache = map;
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (SerDeException e) {
            throw new RuntimeException(e);
        }

        return outputOI;
    }

    private static void loadValues(OpenHashMap<Object, Object> map, File file,
            PrimitiveObjectInspector keyOI, PrimitiveObjectInspector valueOI)
            throws IOException, SerDeException {
        if (!file.exists()) {
            return;
        }
        if (file.getName().endsWith(".crc")) {
            return; // skip checksum files emitted by Hadoop's local file system
        }
        if (file.isDirectory()) {
            final File[] files = file.listFiles();
            if (files != null) { // listFiles() may return null on an I/O error
                for (File f : files) {
                    loadValues(map, f, keyOI, valueOI);
                }
            }
            return;
        }

        // Each line of the file is parsed as a key-value pair using LazySimpleSerDe
        LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI);
        StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
        StructField keyRef = lineOI.getStructFieldRef("key");
        StructField valueRef = lineOI.getStructFieldRef("value");
        PrimitiveObjectInspector keyRefOI =
                (PrimitiveObjectInspector) keyRef.getFieldObjectInspector();
        PrimitiveObjectInspector valueRefOI =
                (PrimitiveObjectInspector) valueRef.getFieldObjectInspector();

        BufferedReader reader = null;
        try {
            reader = HadoopUtils.getBufferedReader(file);
            String line;
            while ((line = reader.readLine()) != null) {
                Text lineText = new Text(line);
                Object lineObj = serde.deserialize(lineText);
                List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
                Object f0 = fields.get(0);
                Object f1 = fields.get(1);
                Object k = keyRefOI.getPrimitiveJavaObject(f0);
                Object v = valueRefOI.getPrimitiveWritableObject(valueRefOI.copyObject(f1));
                map.put(k, v);
            }
        } finally {
            IOUtils.closeQuietly(reader);
        }
    }

    @Override
    public Object evaluate(DeferredObject[] args) throws HiveException {
        final Object arg1 = args[1].get();
        if (multipleKeyLookup) {
            if (multipleDefaultValues) {
                Object arg2 = args[2].get();
                return gets(arg1, arg2);
            } else {
                return gets(arg1);
            }
        } else {
            return get(arg1);
        }
    }

    private Object get(Object arg) {
        Object key = keyInputOI.getPrimitiveJavaObject(arg);
        Object value = cache.get(lookupKey(key));
        return (value == null) ? defaultValue : value;
    }
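    // Miss handling in the multi-key variants below: a null key is skipped entirely, while a
    // cache miss maps the key to its default (the shared constant default, or the positional
    // default when a default-value list is given).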
    private Map<Object, Object> gets(Object arg) {
        List<?> keys = keysInputOI.getList(arg);
        final Map<Object, Object> map = new HashMap<Object, Object>();
        for (Object k : keys) {
            if (k == null) {
                continue;
            }
            Object kj = keyInputOI.getPrimitiveJavaObject(k);
            final Object v = cache.get(lookupKey(kj));
            if (v == null) {
                map.put(k, defaultValue);
            } else {
                map.put(k, v);
            }
        }
        return map;
    }

    private Map<Object, Object> gets(Object argKeys, Object argValues) throws HiveException {
        final List<?> keys = keysInputOI.getList(argKeys);
        final List<?> defaultValues = valuesInputOI.getList(argValues);
        final int numKeys = keys.size();
        if (numKeys != defaultValues.size()) {
            throw new HiveException("# of default values != # of lookup keys: keys " + argKeys
                    + ", values: " + argValues);
        }
        final Map<Object, Object> map = new HashMap<Object, Object>();
        for (int i = 0; i < numKeys; i++) {
            Object k = keys.get(i);
            if (k == null) {
                continue;
            }
            Object kj = keyInputOI.getPrimitiveJavaObject(k);
            Object v = cache.get(lookupKey(kj));
            if (v == null) {
                v = defaultValues.get(i);
                if (v != null) {
                    v = valueInputOI.getPrimitiveWritableObject(valueInputOI.copyObject(v));
                }
            }
            map.put(k, v);
        }
        return map;
    }

    private Object lookupKey(final Object key) {
        if (parseKey) {
            // strip the weight part of a "feature:weight" representation before the lookup
            String keyStr = key.toString();
            return ExtractFeatureUDF.extractFeature(keyStr);
        } else {
            return key;
        }
    }

    @Override
    public String getDisplayString(String[] args) {
        return "distcache_gets()";
    }

    private static String getUsage() {
        return "\nUSAGE: "
                + "\n\tdistcache_gets(const string FILEPATH, object[] keys, const object defaultValue [, const boolean parseKey])::map<key_type, value_type>"
                + "\n\tdistcache_gets(const string FILEPATH, object key, const object defaultValue [, const boolean parseKey])::value_type"
                + "\n\tdistcache_gets(const string FILEPATH, object[] keys, object[] defaultValues [, const boolean parseKey])::map<key_type, value_type>";
    }
}