/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.smile.regression;

import hivemall.UDTFWithOptions;
import hivemall.smile.ModelType;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.SmileTaskExecutor;
import hivemall.smile.vm.StackMachine;
import hivemall.utils.codec.Base91;
import hivemall.utils.codec.DeflateCodec;
import hivemall.utils.collections.DoubleArrayList;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.RandomUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
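/**
 * A UDTF that trains a random forest regressor and forwards one output row per tree.
 *
 * Illustrative usage (the table and column names below are hypothetical, not part of this
 * class; the output aliases follow the struct returned by {@code initialize}):
 *
 * <pre>
 * SELECT
 *   train_randomforest_regression(features, target, '-trees 50 -seed 31')
 *     AS (model_id, model_type, pred_model, var_importance, oob_errors, oob_tests)
 * FROM
 *   training;
 * </pre>
 */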
@Description(name = "train_randomforest_regression",
        value = "_FUNC_(double[] features, double target [, string options]) - "
                + "Returns a relation consisting of "
                + "<string model_id, int model_type, string pred_model, array<double> var_importance, double oob_errors, int oob_tests>")
public final class RandomForestRegressionUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(RandomForestRegressionUDTF.class);

    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector targetOI;

    private List<double[]> featuresList;
    private DoubleArrayList targets;

    /** The number of trees for each task */
    private int _numTrees;
    /** The number of randomly selected features */
    private float _numVars;
    /** The maximum depth of the tree */
    private int _maxDepth;
    /** The maximum number of leaf nodes */
    private int _maxLeafNodes;
    private int _minSamplesSplit;
    private int _minSamplesLeaf;
    private long _seed;
    private Attribute[] _attributes;
    private ModelType _outputType;

    @Nullable
    private Reporter _progressReporter;
    @Nullable
    private Counter _treeBuildTaskCounter;
    @Nullable
    private Counter _treeConstructionTimeCounter;
    @Nullable
    private Counter _treeSerializationTimeCounter;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("trees", "num_trees", true,
            "The number of trees for each task [default: 50]");
        opts.addOption("vars", "num_variables", true,
            "The number of randomly selected features [default: ceil(sqrt(x[0].length))]."
                    + " int(num_variables * x[0].length) is considered if num_variables is in (0,1]");
        opts.addOption("depth", "max_depth", true,
            "The maximum depth of the tree [default: Integer.MAX_VALUE]");
        opts.addOption("leafs", "max_leaf_nodes", true,
            "The maximum number of leaf nodes [default: Integer.MAX_VALUE]");
        opts.addOption("split", "min_split", true,
            "A node that has greater than or equal to `min_split` examples will split [default: 5]");
        opts.addOption("min_samples_leaf", true,
            "The minimum number of samples in a leaf node [default: 1]");
        opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
        opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
                + "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
        opts.addOption("output", "output_type", true,
            "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]");
        opts.addOption("disable_compression", false,
            "Whether to disable compression of the output script [default: false]");
        return opts;
    }
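    // For illustration, a constant option string passed as the third UDTF argument,
    // e.g. '-trees 100 -vars 0.3 -seed 31' (the values here are hypothetical),
    // is parsed by processOptions() below into the hyperparameter fields above.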
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        int trees = 50, maxDepth = Integer.MAX_VALUE;
        int maxLeafs = Integer.MAX_VALUE, minSplit = 5, minSamplesLeaf = 1;
        float numVars = -1.f;
        Attribute[] attrs = null;
        long seed = -1L;
        String output = "serialization";
        boolean compress = true;

        CommandLine cl = null;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = parseOptions(rawArgs);

            trees = Primitives.parseInt(cl.getOptionValue("num_trees"), trees);
            if (trees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + trees);
            }
            numVars = Primitives.parseFloat(cl.getOptionValue("num_variables"), numVars);
            maxDepth = Primitives.parseInt(cl.getOptionValue("max_depth"), maxDepth);
            maxLeafs = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), maxLeafs);
            minSplit = Primitives.parseInt(cl.getOptionValue("min_split"), minSplit);
            minSamplesLeaf = Primitives.parseInt(cl.getOptionValue("min_samples_leaf"),
                minSamplesLeaf);
            seed = Primitives.parseLong(cl.getOptionValue("seed"), seed);
            attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
            output = cl.getOptionValue("output", output);
            if (cl.hasOption("disable_compression")) {
                compress = false;
            }
        }

        this._numTrees = trees;
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        this._maxLeafNodes = maxLeafs;
        this._minSamplesSplit = minSplit;
        this._minSamplesLeaf = minSamplesLeaf;
        this._seed = seed;
        this._attributes = attrs;
        this._outputType = ModelType.resolve(output, compress);

        return cl;
    }

    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes 2 or 3 arguments: double[] features, double target [, const string options]: "
                    + argOIs.length);
        }

        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);

        processOptions(argOIs);

        this.featuresList = new ArrayList<double[]>(1024);
        this.targets = new DoubleArrayList(1024);

        ArrayList<String> fieldNames = new ArrayList<String>(6);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);

        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("model_type");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pred_model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        fieldNames.add("var_importance");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        fieldNames.add("oob_errors");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        fieldNames.add("oob_tests");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(Object[] args) throws HiveException {
        if (args[0] == null) {
            throw new HiveException("array<double> features was null");
        }
        double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI);
        double target = PrimitiveObjectInspectorUtils.getDouble(args[1], targetOI);

        featuresList.add(features);
        targets.add(target);
    }

    @Override
    public void close() throws HiveException {
        this._progressReporter = getReporter();
        this._treeBuildTaskCounter = (_progressReporter == null) ? null
                : _progressReporter.getCounter("hivemall.smile.RandomForestRegression$Counter",
                    "Number of finished tree construction tasks");
        this._treeConstructionTimeCounter = (_progressReporter == null) ? null
                : _progressReporter.getCounter("hivemall.smile.RandomForestRegression$Counter",
                    "Elapsed time in seconds for tree construction");
        this._treeSerializationTimeCounter = (_progressReporter == null) ? null
                : _progressReporter.getCounter("hivemall.smile.RandomForestRegression$Counter",
                    "Elapsed time in seconds for tree serialization");

        reportProgress(_progressReporter);

        int numExamples = featuresList.size();
        if (numExamples > 0) {
            double[][] x = featuresList.toArray(new double[numExamples][]);
            this.featuresList = null;
            double[] y = targets.toArray();
            this.targets = null;

            // run training
            train(x, y);
        }

        // clean up
        this.featureListOI = null;
        this.featureElemOI = null;
        this.targetOI = null;
        this._attributes = null;
    }

    private void checkOptions() throws HiveException {
        if (_minSamplesSplit <= 0) {
            throw new HiveException("Invalid minSamplesSplit: " + _minSamplesSplit);
        }
        if (_maxDepth < 1) {
            throw new HiveException("Invalid maxDepth: " + _maxDepth);
        }
    }
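    // A worked example of the `num_variables` resolution performed by
    // SmileExtUtils.computeNumInputVars() in train() below, per the option help in
    // getOptions() (the feature counts here are illustrative): with 100 features and
    // the default of -1, ceil(sqrt(100)) = 10 variables are examined at each split;
    // with `-vars 0.25`, int(0.25 * 100) = 25 variables are examined instead.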
    /**
     * @param x features
     * @param y target values
     */
    private void train(@Nonnull final double[][] x, @Nonnull final double[] y)
            throws HiveException {
        if (x.length != y.length) {
            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
                x.length, y.length));
        }
        checkOptions();

        // Shuffle training samples
        SmileExtUtils.shuffle(x, y, _seed);

        Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
        int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);

        if (logger.isInfoEnabled()) {
            logger.info("numTrees: " + _numTrees + ", numVars: " + numInputVars
                    + ", minSamplesSplit: " + _minSamplesSplit + ", minSamplesLeaf: "
                    + _minSamplesLeaf + ", maxDepth: " + _maxDepth + ", maxLeafs: "
                    + _maxLeafNodes + ", seed: " + _seed);
        }

        int numExamples = x.length;
        double[] prediction = new double[numExamples]; // placeholder for out-of-bag prediction
        int[] oob = new int[numExamples];
        int[][] order = SmileExtUtils.sort(attributes, x);
        AtomicInteger remainingTasks = new AtomicInteger(_numTrees);

        List<TrainingTask> tasks = new ArrayList<TrainingTask>();
        for (int i = 0; i < _numTrees; i++) {
            long s = (_seed == -1L) ? -1L : _seed + i;
            tasks.add(new TrainingTask(this, i, attributes, x, y, numInputVars, order, prediction,
                oob, s, remainingTasks));
        }

        MapredContext mapredContext = MapredContextAccessor.get();
        final SmileTaskExecutor executor = new SmileTaskExecutor(mapredContext);
        try {
            executor.run(tasks);
        } catch (Exception ex) {
            throw new HiveException(ex);
        } finally {
            executor.shotdown();
        }
    }

    /**
     * Synchronized because {@link #forward(Object)} should be called from a single thread.
     */
    synchronized void forward(final int taskId, @Nonnull final Text model,
            @Nonnull final double[] importance, final double[] y, final double[] prediction,
            final int[] oob, final boolean lastTask) throws HiveException {
        double oobErrors = 0.d;
        int oobTests = 0;
        if (lastTask) {
            // out-of-bag error estimate
            for (int i = 0; i < y.length; i++) {
                if (oob[i] > 0) {
                    oobTests++;
                    double pred = prediction[i] / oob[i];
                    oobErrors += smile.math.Math.sqr(pred - y[i]);
                }
            }
        }
        String modelId = RandomUtils.getUUID();
        final Object[] forwardObjs = new Object[6];
        forwardObjs[0] = new Text(modelId);
        forwardObjs[1] = new IntWritable(_outputType.getId());
        forwardObjs[2] = model;
        forwardObjs[3] = WritableUtils.toWritableList(importance);
        forwardObjs[4] = new DoubleWritable(oobErrors);
        forwardObjs[5] = new IntWritable(oobTests);
        forward(forwardObjs);

        reportProgress(_progressReporter);
        incrCounter(_treeBuildTaskCounter, 1);

        logger.info("Forwarded " + taskId + "-th RegressionTree out of " + _numTrees);
    }
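    // Bagging note: each task draws N training indices with replacement, so a row is
    // left out of a given tree's bootstrap sample with probability (1 - 1/N)^N, which
    // approaches 1/e (about 36.8%) for large N. Those out-of-bag rows accumulate
    // prediction sums in `prediction` and visit counts in `oob`; the last task sums
    // the squared errors of the averaged OOB predictions into oob_errors, so the OOB
    // mean squared error of the forest is oob_errors / oob_tests.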
    /**
     * Trains a regression tree.
     */
    private static final class TrainingTask implements Callable<Integer> {
        /**
         * Attribute properties.
         */
        private final Attribute[] _attributes;
        /**
         * Training instances.
         */
        private final double[][] _x;
        /**
         * Training sample target values.
         */
        private final double[] _y;
        /**
         * The index of training values in ascending order. Note that only numeric attributes will
         * be sorted.
         */
        private final int[][] _order;
        /**
         * The number of variables to pick up in each node.
         */
        private final int _numVars;
        /**
         * The out-of-bag predictions.
         */
        private final double[] _prediction;
        /**
         * Out-of-bag sample counts.
         */
        private final int[] _oob;

        private final RandomForestRegressionUDTF _udtf;
        private final int _taskId;
        private final long _seed;
        private final AtomicInteger _remainingTasks;

        TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes,
                double[][] x, double[] y, int numVars, int[][] order, double[] prediction,
                int[] oob, long seed, AtomicInteger remainingTasks) {
            this._udtf = udtf;
            this._taskId = taskId;
            this._attributes = attributes;
            this._x = x;
            this._y = y;
            this._order = order;
            this._numVars = numVars;
            this._prediction = prediction;
            this._oob = oob;
            this._seed = seed;
            this._remainingTasks = remainingTasks;
        }

        @Override
        public Integer call() throws HiveException {
            long s = (this._seed == -1L) ? SmileExtUtils.generateSeed()
                    : new smile.math.Random(_seed).nextLong();
            final smile.math.Random rnd1 = new smile.math.Random(s);
            final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong());

            final int N = _x.length;

            // Training samples are drawn with replacement.
            final int[] bags = new int[N];
            final BitSet sampled = new BitSet(N);
            for (int i = 0; i < N; i++) {
                int index = rnd1.nextInt(N);
                bags[i] = index;
                sampled.set(index);
            }

            StopWatch stopwatch = new StopWatch();
            RegressionTree tree = new RegressionTree(_attributes, _x, _y, _numVars,
                _udtf._maxDepth, _udtf._maxLeafNodes, _udtf._minSamplesSplit,
                _udtf._minSamplesLeaf, _order, bags, rnd2);
            incrCounter(_udtf._treeConstructionTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));

            // out-of-bag prediction
            for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) {
                double pred = tree.predict(_x[i]);
                synchronized (_x[i]) {
                    _prediction[i] += pred;
                    _oob[i]++;
                }
            }

            stopwatch.reset().start();
            Text model = getModel(tree, _udtf._outputType);
            double[] importance = tree.importance();
            tree = null; // help GC
            int remain = _remainingTasks.decrementAndGet();
            boolean lastTask = (remain == 0);
            _udtf.forward(_taskId + 1, model, importance, _y, _prediction, _oob, lastTask);
            incrCounter(_udtf._treeSerializationTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS));

            return Integer.valueOf(remain);
        }

        private static Text getModel(@Nonnull final RegressionTree tree,
                @Nonnull final ModelType outputType) throws HiveException {
            final Text model;
            switch (outputType) {
                case serialization:
                case serialization_compressed: {
                    byte[] b = tree.predictSerCodegen(outputType.isCompressed());
                    b = Base91.encode(b);
                    model = new Text(b);
                    break;
                }
                case opscode:
                case opscode_compressed: {
                    String s = tree.predictOpCodegen(StackMachine.SEP);
                    if (outputType.isCompressed()) {
                        byte[] b = s.getBytes();
                        final DeflateCodec codec = new DeflateCodec(true, false);
                        try {
                            b = codec.compress(b);
                        } catch (IOException e) {
                            throw new HiveException("Failed to compress a model", e);
                        } finally {
                            IOUtils.closeQuietly(codec);
                        }
                        b = Base91.encode(b);
                        model = new Text(b);
                    } else {
                        model = new Text(s);
                    }
                    break;
                }
                case javascript:
                case javascript_compressed: {
                    String s = tree.predictJsCodegen();
                    if (outputType.isCompressed()) {
                        byte[] b = s.getBytes();
                        final DeflateCodec codec = new DeflateCodec(true, false);
                        try {
                            b = codec.compress(b);
                        } catch (IOException e) {
                            throw new HiveException("Failed to compress a model", e);
                        } finally {
                            IOUtils.closeQuietly(codec);
                        }
                        b = Base91.encode(b);
                        model = new Text(b);
                    } else {
                        model = new Text(s);
                    }
                    break;
                }
                default:
                    throw new HiveException("Unexpected output type: " + outputType
                            + ". Use javascript for the output instead");
            }
            return model;
        }
    }
}