/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.smile.tools;

import hivemall.smile.ModelType;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.vm.StackMachine;
import hivemall.smile.vm.VMRuntimeException;
import hivemall.utils.codec.Base91;
import hivemall.utils.codec.DeflateCodec;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;

import java.io.Closeable;
import java.io.IOException;
import java.util.Arrays;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.script.Bindings;
import javax.script.Compilable;
import javax.script.CompiledScript;
import javax.script.ScriptEngine;
import javax.script.ScriptEngineManager;
import javax.script.ScriptException;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
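/**
 * Evaluates a serialized decision/regression tree model against a feature vector.
 *
 * Illustrative HiveQL usage (a sketch; the table and column names below are
 * assumptions chosen for the example, not fixed by this UDF):
 *
 * <pre>
 * SELECT
 *   t.rowid,
 *   tree_predict(m.model_id, m.model_type, m.pred_model, t.features, true) AS predicted
 * FROM
 *   model m
 *   CROSS JOIN test t;
 * </pre>
 */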
@Description(name = "tree_predict",
        value = "_FUNC_(string modelId, int modelType, string script, array<double> features"
                + " [, const boolean classification])"
                + " - Returns a prediction result of a random forest")
@UDFType(deterministic = true, stateful = false)
public final class TreePredictUDF extends GenericUDF {

    private boolean classification;
    private PrimitiveObjectInspector modelTypeOI;
    private StringObjectInspector stringOI;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;

    @Nullable
    private transient Evaluator evaluator;

    private boolean supportJavascriptEval = true;

    @Override
    public void configure(MapredContext context) {
        super.configure(context);

        if (context != null) {
            JobConf conf = context.getJobConf();
            String tdJarVersion = conf.get("td.jar.version");
            if (tdJarVersion != null) {
                this.supportJavascriptEval = false;
            }
        }
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments");
        }

        this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]);
        this.stringOI = HiveUtils.asStringOI(argOIs[2]);
        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]);
        this.featureListOI = listOI;
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);

        boolean classification = false;
        if (argOIs.length == 5) {
            classification = HiveUtils.getConstBoolean(argOIs[4]);
        }
        this.classification = classification;

        if (classification) {
            return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
        } else {
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }
    }

    @Override
    public Writable evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
        Object arg0 = arguments[0].get();
        if (arg0 == null) {
            throw new HiveException("ModelId was null");
        }
        // Not using the string OI for backward compatibility
        String modelId = arg0.toString();

        Object arg1 = arguments[1].get();
        int modelTypeId = PrimitiveObjectInspectorUtils.getInt(arg1, modelTypeOI);
        ModelType modelType = ModelType.resolve(modelTypeId);

        Object arg2 = arguments[2].get();
        if (arg2 == null) {
            return null;
        }
        Text script = stringOI.getPrimitiveWritableObject(arg2);

        Object arg3 = arguments[3].get();
        if (arg3 == null) {
            throw new HiveException("array<double> features was null");
        }
        double[] features = HiveUtils.asDoubleArray(arg3, featureListOI, featureElemOI);

        if (evaluator == null) {
            this.evaluator = getEvaluator(modelType, supportJavascriptEval);
        }

        return evaluator.evaluate(modelId, modelType.isCompressed(), script, features,
            classification);
    }

    @Nonnull
    private static Evaluator getEvaluator(@Nonnull ModelType type, boolean supportJavascriptEval)
            throws UDFArgumentException {
        final Evaluator evaluator;
        switch (type) {
            case serialization:
            case serialization_compressed: {
                evaluator = new JavaSerializationEvaluator();
                break;
            }
            case opscode:
            case opscode_compressed: {
                evaluator = new StackmachineEvaluator();
                break;
            }
            case javascript:
            case javascript_compressed: {
                if (!supportJavascriptEval) {
                    throw new UDFArgumentException(
                        "Javascript evaluation is not allowed in Treasure Data env");
                }
                evaluator = new JavascriptEvaluator();
                break;
            }
            default:
                throw new UDFArgumentException("Unexpected model type was detected: " + type);
        }
        return evaluator;
    }

    @Override
    public void close() throws IOException {
        this.modelTypeOI = null;
        this.stringOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
        IOUtils.closeQuietly(evaluator);
        this.evaluator = null;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "tree_predict(" + Arrays.toString(children) + ")";
    }

    public interface Evaluator extends Closeable {

        @Nullable
        Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
                @Nonnull double[] features, boolean classification) throws HiveException;

    }
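    // Note: each Evaluator below keeps a single-entry cache keyed by the previously
    // seen modelId. Decoding (Base91, plus Deflate when the model is compressed) and
    // tree deserialization/compilation are skipped while consecutive rows reference
    // the same model, which is the common case when predicting with one model at a
    // time; only the most recently seen model is retained.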
    static final class JavaSerializationEvaluator implements Evaluator {

        @Nullable
        private String prevModelId = null;
        private DecisionTree.Node cNode = null;
        private RegressionTree.Node rNode = null;

        JavaSerializationEvaluator() {}

        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
                double[] features, boolean classification) throws HiveException {
            if (classification) {
                return evaluateClassification(modelId, compressed, script, features);
            } else {
                return evaluateRegression(modelId, compressed, script, features);
            }
        }

        private IntWritable evaluateClassification(@Nonnull String modelId, boolean compressed,
                @Nonnull Text script, double[] features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                this.prevModelId = modelId;
                int length = script.getLength();
                byte[] b = script.getBytes();
                b = Base91.decode(b, 0, length);
                this.cNode = DecisionTree.deserializeNode(b, b.length, compressed);
            }
            assert (cNode != null);
            int result = cNode.predict(features);
            return new IntWritable(result);
        }

        private DoubleWritable evaluateRegression(@Nonnull String modelId, boolean compressed,
                @Nonnull Text script, double[] features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                this.prevModelId = modelId;
                int length = script.getLength();
                byte[] b = script.getBytes();
                b = Base91.decode(b, 0, length);
                this.rNode = RegressionTree.deserializeNode(b, b.length, compressed);
            }
            assert (rNode != null);
            double result = rNode.predict(features);
            return new DoubleWritable(result);
        }

        @Override
        public void close() throws IOException {}

    }

    static final class StackmachineEvaluator implements Evaluator {

        private String prevModelId = null;
        private StackMachine prevVM = null;
        private DeflateCodec codec = null;

        StackmachineEvaluator() {}

        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
                double[] features, boolean classification) throws HiveException {
            final String scriptStr;
            if (compressed) {
                if (codec == null) {
                    this.codec = new DeflateCodec(false, true);
                }
                byte[] b = script.getBytes();
                int len = script.getLength();
                b = Base91.decode(b, 0, len);
                try {
                    b = codec.decompress(b);
                } catch (IOException e) {
                    throw new HiveException("decompression failed", e);
                }
                scriptStr = new String(b);
            } else {
                scriptStr = script.toString();
            }

            final StackMachine vm;
            if (modelId.equals(prevModelId)) {
                vm = prevVM;
            } else {
                vm = new StackMachine();
                try {
                    vm.compile(scriptStr);
                } catch (VMRuntimeException e) {
                    throw new HiveException("failed to compile StackMachine", e);
                }
                this.prevModelId = modelId;
                this.prevVM = vm;
            }

            try {
                vm.eval(features);
            } catch (VMRuntimeException vme) {
                throw new HiveException("failed to eval StackMachine", vme);
            } catch (Throwable e) {
                throw new HiveException("failed to eval StackMachine", e);
            }

            Double result = vm.getResult();
            if (result == null) {
                return null;
            }
            if (classification) {
                return new IntWritable(result.intValue());
            } else {
                return new DoubleWritable(result.doubleValue());
            }
        }

        @Override
        public void close() throws IOException {
            IOUtils.closeQuietly(codec);
        }

    }
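    // The javascript model is compiled once per model via javax.script and then
    // evaluated per row with the feature vector bound to the variable "x"; the
    // script is expected to yield a numeric result. Which engine backs the "js"
    // extension depends on the JRE (typically Rhino on Java 6/7, Nashorn on
    // Java 8), so generated scripts should stick to engine-neutral constructs.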
    static final class JavascriptEvaluator implements Evaluator {

        private final ScriptEngine scriptEngine;
        private final Compilable compilableEngine;

        private String prevModelId = null;
        private CompiledScript prevCompiled = null;
        private DeflateCodec codec = null;

        JavascriptEvaluator() throws UDFArgumentException {
            ScriptEngineManager manager = new ScriptEngineManager();
            ScriptEngine engine = manager.getEngineByExtension("js");
            if (!(engine instanceof Compilable)) {
                throw new UDFArgumentException("ScriptEngine was not compilable: "
                        + engine.getFactory().getEngineName() + " version "
                        + engine.getFactory().getEngineVersion());
            }
            this.scriptEngine = engine;
            this.compilableEngine = (Compilable) engine;
        }

        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
                double[] features, boolean classification) throws HiveException {
            final String scriptStr;
            if (compressed) {
                if (codec == null) {
                    this.codec = new DeflateCodec(false, true);
                }
                byte[] b = script.getBytes();
                int len = script.getLength();
                b = Base91.decode(b, 0, len);
                try {
                    b = codec.decompress(b);
                } catch (IOException e) {
                    throw new HiveException("decompression failed", e);
                }
                scriptStr = new String(b);
            } else {
                scriptStr = script.toString();
            }

            final CompiledScript compiled;
            if (modelId.equals(prevModelId)) {
                compiled = prevCompiled;
            } else {
                try {
                    compiled = compilableEngine.compile(scriptStr);
                } catch (ScriptException e) {
                    throw new HiveException("failed to compile: \n" + script, e);
                }
                this.prevModelId = modelId;
                this.prevCompiled = compiled;
            }

            final Bindings bindings = scriptEngine.createBindings();
            final Object result;
            try {
                bindings.put("x", features);
                result = compiled.eval(bindings);
            } catch (ScriptException se) {
                throw new HiveException("failed to evaluate: \n" + script, se);
            } catch (Throwable e) {
                throw new HiveException("failed to evaluate: \n" + script, e);
            } finally {
                bindings.clear();
            }

            if (result == null) {
                return null;
            }
            if (!(result instanceof Number)) {
                throw new HiveException("Got an unexpected non-number result: " + result);
            }
            Number casted = (Number) result;
            if (classification) {
                return new IntWritable(casted.intValue());
            } else {
                return new DoubleWritable(casted.doubleValue());
            }
        }

        @Override
        public void close() throws IOException {
            IOUtils.closeQuietly(codec);
        }

    }
}