/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.xgboost;

import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Primitives;

/**
 * Base UDTF for batched XGBoost prediction.
 *
 * <p>Rows are buffered per model id; once a model's buffer reaches
 * {@code batch_size} rows it is predicted and flushed. Remaining buffered
 * rows are flushed in {@link #close()}. Subclasses define the output schema
 * via {@link #getReturnOI()} and emit rows via
 * {@link #forwardPredicted(List, float[][])}.
 */
public abstract class XGBoostPredictUDTF extends UDTFWithOptions {

    // For input parameters
    private PrimitiveObjectInspector rowIdOI;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector modelIdOI;
    private PrimitiveObjectInspector modelOI;

    // For input buffer: model id -> deserialized booster / pending rows
    private Map<String, Booster> mapToModel;
    private Map<String, List<LabeledPointWithRowId>> rowBuffer;
    private int batch_size;

    // Settings for the XGBoost native library
    static {
        NativeLibLoader.initXGBoost();
    }

    public XGBoostPredictUDTF() {}

    /** A feature vector paired with the row id it came from. */
    protected final class LabeledPointWithRowId {
        public String rowId;
        public LabeledPoint point;

        // Prevent other classes from instantiating this
        LabeledPointWithRowId() {}
    }

    private LabeledPointWithRowId createLabeledPoint(String rowId, LabeledPoint point) {
        final LabeledPointWithRowId p = new LabeledPointWithRowId();
        p.rowId = rowId;
        p.point = point;
        return p;
    }

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("batch_size", true, "Number of rows to predict together [default: 128]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        int batchSize = 128;
        CommandLine cl = null;
        if (argOIs.length >= 5) {
            String rawArgs = HiveUtils.getConstString(argOIs[4]);
            cl = parseOptions(rawArgs);
            // BUGFIX: the option is registered as "batch_size" in getOptions(),
            // but the lookup key here was "_batch_size", so user-supplied values
            // were silently ignored and the default 128 was always used.
            batchSize = Primitives.parseInt(cl.getOptionValue("batch_size"), batchSize);
            if (batchSize < 1) {
                // UDFArgumentException (declared) instead of IllegalArgumentException,
                // so Hive reports it as an argument error to the query author
                throw new UDFArgumentException("batch_size must be greater than 0: " + batchSize);
            }
        }
        this.batch_size = batchSize;
        return cl;
    }

    /** Override this to output predicted results depending on a task type */
    abstract public StructObjectInspector getReturnOI();

    abstract public void forwardPredicted(final List<LabeledPointWithRowId> testData,
            final float[][] predicted) throws HiveException;

    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            throw new UDFArgumentException(this.getClass().getSimpleName()
                    + " takes 4 or 5 arguments: string rowid, string[] features, string model_id,"
                    + " array<byte> pred_model [, string options]: " + argOIs.length);
        } else {
            this.processOptions(argOIs);
            this.rowIdOI = HiveUtils.asStringOI(argOIs[0]);
            final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[1]);
            final ObjectInspector elemOI = listOI.getListElementObjectInspector();
            this.featureListOI = listOI;
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.modelIdOI = HiveUtils.asStringOI(argOIs[2]);
            this.modelOI = HiveUtils.asBinaryOI(argOIs[3]);
            this.mapToModel = new HashMap<String, Booster>();
            this.rowBuffer = new HashMap<String, List<LabeledPointWithRowId>>();
            return getReturnOI();
        }
    }

    /** Packs the buffered points into a DMatrix for a single predict() call. */
    private static DMatrix createDMatrix(final List<LabeledPointWithRowId> data)
            throws XGBoostError {
        final List<LabeledPoint> points = new ArrayList<LabeledPoint>(data.size());
        for (LabeledPointWithRowId d : data) {
            points.add(d.point);
        }
        return new DMatrix(points.iterator(), "");
    }

    /** Deserializes a Booster from the binary model column. */
    private static Booster initXgBooster(final byte[] input) throws HiveException {
        try {
            return XGBoost.loadModel(new ByteArrayInputStream(input));
        } catch (Exception e) {
            // keep the cause so the native/deserialization stack trace survives
            throw new HiveException("Failed to load a Booster", e);
        }
    }

    /** Predicts all rows in {@code buf}, forwards the results, and clears the buffer. */
    private void predictAndFlush(final Booster model, final List<LabeledPointWithRowId> buf)
            throws HiveException {
        try {
            final DMatrix testData = createDMatrix(buf);
            final float[][] predicted = model.predict(testData);
            forwardPredicted(buf, predicted);
        } catch (Exception e) {
            // keep the cause; e.getMessage() alone may be null and loses the trace
            throw new HiveException("Exception caused in the prediction", e);
        }
        buf.clear();
    }

    @Override
    public void process(Object[] args) throws HiveException {
        // Rows with a NULL feature array are skipped entirely
        if (args[1] != null) {
            final String rowId = PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI);
            @SuppressWarnings("unchecked")
            final List<String> features = (List<String>) featureListOI.getList(args[1]);
            final String modelId = PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI);
            // Lazily deserialize each distinct model once per task
            if (!mapToModel.containsKey(modelId)) {
                final byte[] predModel =
                        PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI).getBytes();
                mapToModel.put(modelId, initXgBooster(predModel));
            }
            final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, features);
            if (point != null) {
                if (!rowBuffer.containsKey(modelId)) {
                    rowBuffer.put(modelId, new ArrayList<LabeledPointWithRowId>());
                }
                final List<LabeledPointWithRowId> buf = rowBuffer.get(modelId);
                buf.add(createLabeledPoint(rowId, point));
                if (buf.size() >= batch_size) {
                    predictAndFlush(mapToModel.get(modelId), buf);
                }
            }
        }
    }

    @Override
    public void close() throws HiveException {
        // Flush any rows still buffered below batch_size for each model
        for (Entry<String, List<LabeledPointWithRowId>> e : rowBuffer.entrySet()) {
            predictAndFlush(mapToModel.get(e.getKey()), e.getValue());
        }
    }
}