/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.ftvec.binning; import hivemall.utils.hadoop.HiveUtils; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import java.util.*; @Description( name = "feature_binning", value = "_FUNC_(array<features::string> features, const map<string, array<number>> quantiles_map)" + " / _FUNC(number weight, const array<number> quantiles)" + " - Returns binned features as an array<features::string> / bin ID as int") @UDFType(deterministic = true, stateful = false) public final class FeatureBinningUDF extends GenericUDF { private boolean multiple = true; private ListObjectInspector featuresOI; private StringObjectInspector featureOI; private MapObjectInspector quantilesMapOI; private StringObjectInspector keyOI; private ListObjectInspector quantilesOI; private PrimitiveObjectInspector quantileOI; private PrimitiveObjectInspector weightOI; private Map<Text, double[]> quantilesMap = null; private double[] quantiles = null; @Override public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException { if (OIs.length != 2) { throw new UDFArgumentLengthException("Specify two arguments"); } if (HiveUtils.isListOI(OIs[0]) && HiveUtils.isMapOI(OIs[1])) { // for (array<features::string> features, const map<string, array<number>> quantiles_map) if (!HiveUtils.isStringOI(((ListObjectInspector) OIs[0]).getListElementObjectInspector())) { throw new UDFArgumentTypeException(0, "Only array<string> type argument is acceptable but " + OIs[0].getTypeName() + " was passed as `features`"); } featuresOI = HiveUtils.asListOI(OIs[0]); featureOI = HiveUtils.asStringOI(featuresOI.getListElementObjectInspector()); quantilesMapOI = HiveUtils.asMapOI(OIs[1]); if (!HiveUtils.isStringOI(quantilesMapOI.getMapKeyObjectInspector()) || !HiveUtils.isListOI(quantilesMapOI.getMapValueObjectInspector()) || !HiveUtils.isNumberOI(((ListObjectInspector) quantilesMapOI.getMapValueObjectInspector()).getListElementObjectInspector())) { throw new UDFArgumentTypeException(1, "Only map<string, array<number>> type argument is acceptable but " + OIs[1].getTypeName() + " was passed as `quantiles_map`"); } keyOI = HiveUtils.asStringOI(quantilesMapOI.getMapKeyObjectInspector()); quantilesOI = HiveUtils.asListOI(quantilesMapOI.getMapValueObjectInspector()); quantileOI = HiveUtils.asDoubleCompatibleOI(quantilesOI.getListElementObjectInspector()); multiple = true; return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); } else if (HiveUtils.isPrimitiveOI(OIs[0]) && HiveUtils.isListOI(OIs[1])) { // for (number weight, const array<number> quantiles) weightOI = HiveUtils.asDoubleCompatibleOI(OIs[0]); quantilesOI = HiveUtils.asListOI(OIs[1]); if (!HiveUtils.isNumberOI(quantilesOI.getListElementObjectInspector())) { throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but " + OIs[1].getTypeName() + " was passed as `quantiles`"); } quantileOI = HiveUtils.asDoubleCompatibleOI(quantilesOI.getListElementObjectInspector()); multiple = false; return PrimitiveObjectInspectorFactory.writableIntObjectInspector; } else { throw new UDFArgumentTypeException(0, "Only <array<features::string>, map<string, array<number>>> " + "or <number, array<number>> type arguments are accepted but <" + OIs[0].getTypeName() + ", " + OIs[1].getTypeName() + "> was passed."); } } @Override public Object evaluate(DeferredObject[] dObj) throws HiveException { if (multiple) { // init quantilesMap if (quantilesMap == null) { quantilesMap = new HashMap<Text, double[]>(); final Map<?, ?> _quantilesMap = quantilesMapOI.getMap(dObj[1].get()); for (Object _key : _quantilesMap.keySet()) { final Text key = new Text(keyOI.getPrimitiveJavaObject(_key)); final double[] val = HiveUtils.asDoubleArray(_quantilesMap.get(key), quantilesOI, quantileOI); quantilesMap.put(key, val); } } final List<?> fs = featuresOI.getList(dObj[0].get()); final List<Text> result = new ArrayList<Text>(); for (Object f : fs) { final String entry = featureOI.getPrimitiveJavaObject(f); final int pos = entry.indexOf(":"); if (pos < 0) { // categorical result.add(new Text(entry)); } else { // quantitative final Text key = new Text(entry.substring(0, pos)); String val = entry.substring(pos + 1); // binning if (quantilesMap.containsKey(key)) { val = String.valueOf(findBin(quantilesMap.get(key), Double.parseDouble(val))); } result.add(new Text(key + ":" + val)); } } return result; } else { // init quantiles if (quantiles == null) { quantiles = HiveUtils.asDoubleArray(dObj[1].get(), quantilesOI, quantileOI); } return new IntWritable(findBin(quantiles, PrimitiveObjectInspectorUtils.getDouble(dObj[0].get(), weightOI))); } } private int findBin(double[] _quantiles, double d) throws HiveException { if (_quantiles.length < 3) { throw new HiveException( "Length of `quantiles` should be greater than or equal to three but " + _quantiles.length + "."); } int res = Arrays.binarySearch(_quantiles, d); return (res < 0) ? ~res - 1 : (res == 0) ? 0 : res - 1; } @Override public String getDisplayString(String[] children) { final StringBuilder sb = new StringBuilder(); sb.append("feature_binning"); sb.append("("); if (children.length > 0) { sb.append(children[0]); for (int i = 1; i < children.length; i++) { sb.append(", "); sb.append(children[i]); } } sb.append(")"); return sb.toString(); } }