/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.fm; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.NumberUtils; import java.io.IOException; import java.util.Arrays; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.io.Text; @Description(name = "ffm_predict", value = "_FUNC_(string modelId, string model, array<string> features)" + " returns a prediction result in double from a Field-aware Factorization Machine") @UDFType(deterministic = true, stateful = false) public final class FFMPredictUDF extends GenericUDF { private StringObjectInspector _modelIdOI; private StringObjectInspector _modelOI; private ListObjectInspector _featureListOI; private DoubleWritable _result; @Nullable private String _cachedModeId; @Nullable private FFMPredictionModel _cachedModel; @Nullable private Feature[] _probes; public FFMPredictUDF() {} @Override public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length != 3) { throw new UDFArgumentException("_FUNC_ takes 3 arguments"); } this._modelIdOI = HiveUtils.asStringOI(argOIs[0]); this._modelOI = HiveUtils.asStringOI(argOIs[1]); this._featureListOI = HiveUtils.asListOI(argOIs[2]); this._result = new DoubleWritable(); return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; } @Override public Object evaluate(DeferredObject[] args) throws HiveException { String modelId = _modelIdOI.getPrimitiveJavaObject(args[0].get()); if (modelId == null) { throw new HiveException("modelId is not set"); } final FFMPredictionModel model; if (modelId.equals(_cachedModeId)) { model = this._cachedModel; } else { Text serModel = _modelOI.getPrimitiveWritableObject(args[1].get()); if (serModel == null) { throw new HiveException("Model is null for model ID: " + modelId); } byte[] b = serModel.getBytes(); final int length = serModel.getLength(); try { model = FFMPredictionModel.deserialize(b, length); b = null; } catch (ClassNotFoundException e) { throw new HiveException(e); } catch (IOException e) { throw new HiveException(e); } this._cachedModeId = modelId; this._cachedModel = model; } int numFeatures = model.getNumFeatures(); int numFields = model.getNumFields(); Object arg2 = args[2].get(); // [workaround] // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray // cannot be cast to [Ljava.lang.Object; if (arg2 instanceof LazyBinaryArray) { arg2 = ((LazyBinaryArray) arg2).getList(); } Feature[] x = Feature.parseFFMFeatures(arg2, _featureListOI, _probes, numFeatures, numFields); if (x == null || x.length == 0) { return null; // return NULL if there are no features } this._probes = x; double predicted = predict(x, model); _result.set(predicted); return _result; } private static double predict(@Nonnull final Feature[] x, @Nonnull final FFMPredictionModel model) throws HiveException { // w0 double ret = model.getW0(); // W for (Feature e : x) { double xi = e.getValue(); float wi = model.getW(e); double wx = wi * xi; ret += wx; } // V final int factors = model.getNumFactors(); final float[] vij = new float[factors]; final float[] vji = new float[factors]; for (int i = 0; i < x.length; ++i) { final Feature ei = x[i]; final double xi = ei.getValue(); final int iField = ei.getField(); for (int j = i + 1; j < x.length; ++j) { final Feature ej = x[j]; final double xj = ej.getValue(); final int jField = ej.getField(); if (!model.getV(ei, jField, vij)) { continue; } if (!model.getV(ej, iField, vij)) { continue; } for (int f = 0; f < factors; f++) { float vijf = vij[f]; float vjif = vji[f]; ret += vijf * vjif * xi * xj; } } } if (!NumberUtils.isFinite(ret)) { throw new HiveException("Detected " + ret + " in ffm_predict"); } return ret; } @Override public void close() throws IOException { super.close(); // clean up to help GC this._cachedModel = null; this._probes = null; } @Override public String getDisplayString(String[] args) { return "ffm_predict(" + Arrays.toString(args) + ")"; } }