/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.tools.map;

import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;

/**
 * Hive UDF {@code map_tail_n(map, int)}: returns a map holding the entries of the input map whose
 * keys are the {@code N} greatest ones.
 *
 * <p>
 * Keys are ordered by a {@link TreeMap} using their natural ordering, so the copied key objects
 * must be mutually comparable. The returned map is itself sorted by key in ascending order.
 */
@Description(name = "map_tail_n", value = "_FUNC_(map SRC, int N) - Returns the last N elements "
        + "from a sorted array of SRC")
@UDFType(deterministic = true, stateful = false)
public class MapTailNUDF extends GenericUDF {

    /** Inspector for the first (map) argument; assigned in {@link #initialize}. */
    private MapObjectInspector mapObjectInspector;
    /** Inspector for the second (int) argument; assigned in {@link #initialize}. */
    private IntObjectInspector intObjectInspector;

    /**
     * Validates the argument inspectors and declares the return type.
     *
     * @param arguments inspectors of the call arguments; exactly two are expected, a map followed
     *        by an int
     * @return a standard map inspector whose key inspector is the standard form of the input key
     *         inspector and whose value inspector is the input value inspector as-is
     * @throws UDFArgumentLengthException if the number of arguments is not 2
     * @throws UDFArgumentException if the first argument is not a map or the second is not an int
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length != 2) {
            throw new UDFArgumentLengthException(
                "map_tail_n only takes 2 arguments: map<object, object>, int");
        }

        if (!(arguments[0] instanceof MapObjectInspector)) {
            throw new UDFArgumentException("The first argument must be a map");
        }
        this.mapObjectInspector = (MapObjectInspector) arguments[0];

        if (!(arguments[1] instanceof IntObjectInspector)) {
            throw new UDFArgumentException("The second argument must be an int");
        }
        this.intObjectInspector = (IntObjectInspector) arguments[1];

        // Keys are copied to standard objects in tailN(), hence the standard key inspector.
        // Values are passed through untouched, so their original inspector is kept.
        final ObjectInspector keyOI = ObjectInspectorUtils.getStandardObjectInspector(
            mapObjectInspector.getMapKeyObjectInspector());
        final ObjectInspector valueOI = mapObjectInspector.getMapValueObjectInspector();

        return ObjectInspectorFactory.getStandardMapObjectInspector(keyOI, valueOI);
    }

    /**
     * Evaluates {@code map_tail_n} for one row.
     *
     * @param arguments deferred call arguments: the map and the number of entries to keep
     * @return the entries with the N greatest keys, or {@code null} when either argument is NULL
     * @throws HiveException if a deferred argument cannot be materialized
     */
    @Override
    public Map<?, ?> evaluate(DeferredObject[] arguments) throws HiveException {
        final Object mapObj = arguments[0].get();
        final Object nObj = arguments[1].get();
        // SQL NULL in, SQL NULL out: avoids a NullPointerException on NULL arguments.
        if (mapObj == null || nObj == null) {
            return null;
        }

        final Map<?, ?> map = mapObjectInspector.getMap(mapObj);
        if (map == null) {
            return null;
        }

        final int n = intObjectInspector.get(nObj);
        return tailN(map, n);
    }

    /**
     * Renders the call for plan/explain output.
     *
     * @param arguments string forms of the call arguments
     * @return the display string of this function call
     */
    @Override
    public String getDisplayString(String[] arguments) {
        return "map_tail_n( " + Arrays.toString(arguments) + " )";
    }

    /**
     * Builds a key-sorted map containing only the entries of {@code m} with the {@code n} greatest
     * keys.
     *
     * @param m source map; keys are copied to standard objects, values are reused as-is
     * @param n maximum number of entries to keep; a non-positive value yields an empty map
     * @return a new sorted map with at most {@code n} entries
     */
    private Map<Object, Object> tailN(Map<?, ?> m, int n) {
        if (n <= 0) {
            // Nothing to keep; skip copying and sorting the keys altogether.
            return new TreeMap<Object, Object>();
        }

        final ObjectInspector keyInspector = mapObjectInspector.getMapKeyObjectInspector();
        final TreeMap<Object, Object> sorted = new TreeMap<Object, Object>();
        for (Map.Entry<?, ?> e : m.entrySet()) {
            final Object key = ObjectInspectorUtils.copyToStandardObject(e.getKey(), keyInspector);
            sorted.put(key, e.getValue());
        }

        // Trim the smallest keys in place until only the n greatest remain. When the map
        // already has n or fewer entries, excess is <= 0 and nothing is removed.
        for (int excess = sorted.size() - n; excess > 0; excess--) {
            sorted.pollFirstEntry();
        }
        return sorted;
    }

}