/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.ftvec.amplify; import hivemall.UDTFWithOptions; import hivemall.common.RandomizedAmplifier; import hivemall.common.RandomizedAmplifier.DropoutListener; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.Primitives; import java.util.ArrayList; import java.util.List; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; @Description(name = "rand_amplify", value = "_FUNC_(const int xtimes [, const string options], *)" + " - amplify the input records x-times in map-side") public final class RandomAmplifierUDTF extends UDTFWithOptions implements DropoutListener<Object[]> { private boolean hasOption = false; private long seed = -1L; private int numBuffers = 1000; private transient ObjectInspector[] argOIs; private transient RandomizedAmplifier<Object[]> amplifier; @Override protected Options getOptions() { Options opts = new Options(); opts.addOption("seed", true, "Random seed value [default: -1L (random)]"); opts.addOption("buf", "num_buffers", true, "The number of rows to keep in a buffer [default: 1000]"); return opts; } @Override protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { CommandLine cl = null; if (argOIs.length >= 3 && HiveUtils.isConstString(argOIs[1])) { String rawArgs = HiveUtils.getConstString(argOIs[1]); cl = parseOptions(rawArgs); this.hasOption = true; this.seed = Primitives.parseLong(cl.getOptionValue("seed"), this.seed); this.numBuffers = Primitives.parseInt(cl.getOptionValue("num_buffers"), this.numBuffers); } return cl; } @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { final int numArgs = argOIs.length; if (numArgs < 2) { throw new UDFArgumentException( "_FUNC_(const int xtimes, [, const string options], *) takes at least two arguments"); } // xtimes int xtimes = HiveUtils.getAsConstInt(argOIs[0]); if (xtimes < 1) { throw new UDFArgumentException("Illegal xtimes value: " + xtimes); } this.argOIs = argOIs; processOptions(argOIs); this.amplifier = (seed == -1L) ? new RandomizedAmplifier<Object[]>(numBuffers, xtimes) : new RandomizedAmplifier<Object[]>(numBuffers, xtimes, seed); amplifier.setDropoutListener(this); final List<String> fieldNames = new ArrayList<String>(); final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); final int argStartIndex = hasOption ? 2 : 1; for (int i = argStartIndex; i < numArgs; i++) { fieldNames.add("c" + (i - 1)); ObjectInspector rawOI = argOIs[i]; ObjectInspector retOI = ObjectInspectorUtils.getStandardObjectInspector(rawOI, ObjectInspectorCopyOption.DEFAULT); fieldOIs.add(retOI); } return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @Override public void process(Object[] args) throws HiveException { final int argStartIndex = hasOption ? 2 : 1; final Object[] row = new Object[args.length - argStartIndex]; for (int i = argStartIndex; i < args.length; i++) { Object arg = args[i]; ObjectInspector argOI = argOIs[i]; row[i - argStartIndex] = ObjectInspectorUtils.copyToStandardObject(arg, argOI, ObjectInspectorCopyOption.DEFAULT); } amplifier.add(row); } @Override public void close() throws HiveException { amplifier.sweepAll(); this.amplifier = null; } @Override public void onDrop(Object[] row) throws HiveException { forward(row); } }