/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.dataset;
import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.Random;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
@Description(
name = "lr_datagen",
value = "_FUNC_(options string) - Generates a logistic regression dataset",
extended = "WITH dual AS (SELECT 1) SELECT lr_datagen('-n_examples 1k -n_features 10') FROM dual;")
public final class LogisticRegressionDataGeneratorUDTF extends UDTFWithOptions {
private static final int N_BUFFERS = 1000;
// control variable
private int position;
private float[] labels;
private String[][] featuresArray;
private Float[][] featuresFloatArray;
private int n_examples;
private int n_features;
private int n_dimensions;
private float eps;
private float prob_one;
private long r_seed;
private boolean dense;
private boolean sort;
private boolean classification;
private Random rnd1 = null, rnd2 = null;
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("ne", "n_examples", true,
"Number of training examples created for each task [DEFAULT: 1000]");
opts.addOption("nf", "n_features", true,
"Number of features contained for each example [DEFAULT: 10]");
opts.addOption("nd", "n_dims", true, "The size of feature dimensions [DEFAULT: 200]");
opts.addOption("eps", true,
"eps Epsilon factor by which positive examples are scaled [DEFAULT: 3.0]");
opts.addOption("p1", "prob_one", true,
" Probability in [0, 1.0) that a label is 1 [DEFAULT: 0.6]");
opts.addOption("seed", true, "The seed value for random number generator [DEFAULT: 43L]");
opts.addOption(
"dense",
false,
"Make a dense dataset or not. If not specified, a sparse dataset is generated.\n"
+ "For sparse, n_dims should be much larger than n_features. When disabled, n_features must be equals to n_dims ");
opts.addOption("sort", false, "Sort features if specified (used only for sparse dataset)");
opts.addOption("cl", "classification", false,
"Toggle this option on to generate a classification dataset");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length != 1) {
throw new UDFArgumentException("Expected number of arguments is 1: " + argOIs.length);
}
String opts = HiveUtils.getConstString(argOIs[0]);
CommandLine cl = parseOptions(opts);
this.n_examples = NumberUtils.parseInt(cl.getOptionValue("n_examples"), 1000);
this.n_features = NumberUtils.parseInt(cl.getOptionValue("n_features"), 10);
this.n_dimensions = NumberUtils.parseInt(cl.getOptionValue("n_dims"), 200);
this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 3.f);
this.prob_one = Primitives.parseFloat(cl.getOptionValue("prob_one"), 0.6f);
this.r_seed = Primitives.parseLong(cl.getOptionValue("seed"), 43L);
this.dense = cl.hasOption("dense");
this.sort = cl.hasOption("sort");
this.classification = cl.hasOption("classification");
if (n_features > n_dimensions) {
throw new UDFArgumentException("n_features '" + n_features
+ "' should be greater than or equals to n_dimensions '" + n_dimensions + "'");
}
if (dense) {
if (n_features != n_dimensions) {
throw new UDFArgumentException("n_features '" + n_features
+ "' must be equlas to n_dimensions '" + n_dimensions
+ "' when making a dense dataset");
}
}
return cl;
}
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
processOptions(argOIs);
init();
ArrayList<String> fieldNames = new ArrayList<String>(2);
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
fieldNames.add("label");
fieldOIs.add(PrimitiveObjectInspectorFactory.javaFloatObjectInspector);
fieldNames.add("features");
if (dense) {
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector));
} else {
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
}
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
private void init() {
this.labels = new float[N_BUFFERS];
if (dense) {
this.featuresFloatArray = new Float[N_BUFFERS][n_features];
} else {
this.featuresArray = new String[N_BUFFERS][n_features];
}
this.position = 0;
}
@Override
public void process(Object[] argOIs) throws HiveException {
if (rnd1 == null) {
assert (rnd2 == null);
final int taskid = HadoopUtils.getTaskId(-1);
final long seed;
if (taskid == -1) {
seed = r_seed; // Non-MR local task
} else {
seed = r_seed + taskid;
}
this.rnd1 = new Random(seed);
this.rnd2 = new Random(seed + 1);
}
for (int i = 0; i < n_examples; i++) {
if (dense) {
generateDenseData();
} else {
generateSparseData();
}
position++;
if (position == N_BUFFERS) {
flushBuffered(position);
this.position = 0;
}
}
}
private void generateSparseData() throws HiveException {
float label = rnd1.nextFloat();
float sign = (label <= prob_one) ? 1.f : 0.f;
labels[position] = classification ? sign : label;
String[] features = featuresArray[position];
assert (features != null);
final BitSet used = new BitSet(n_dimensions);
int searchClearBitsFrom = 0;
for (int i = 0, retry = 0; i < n_features; i++) {
int f = rnd2.nextInt(n_dimensions);
if (used.get(f)) {
if (retry < 3) {
--i;
++retry;
continue;
}
searchClearBitsFrom = used.nextClearBit(searchClearBitsFrom);
f = searchClearBitsFrom;
}
used.set(f);
float w = (float) rnd2.nextGaussian() + (sign * eps);
String y = f + ":" + w;
features[i] = y;
retry = 0;
}
if (sort) {
Arrays.sort(features, new Comparator<String>() {
@Override
public int compare(String o1, String o2) {
int i1 = Integer.parseInt(o1.split(":")[0]);
int i2 = Integer.parseInt(o2.split(":")[0]);
return Primitives.compare(i1, i2);
}
});
}
}
private void generateDenseData() {
float label = rnd1.nextFloat();
float sign = (label >= prob_one) ? 1.f : 0.f;
labels[position] = classification ? sign : label;
Float[] features = featuresFloatArray[position];
assert (features != null);
for (int i = 0; i < n_features; i++) {
float w = (float) rnd2.nextGaussian() + (sign * eps);
features[i] = Float.valueOf(w);
}
}
private void flushBuffered(int position) throws HiveException {
final Object[] forwardObjs = new Object[2];
if (dense) {
for (int i = 0; i < position; i++) {
forwardObjs[0] = Float.valueOf(labels[i]);
forwardObjs[1] = Arrays.asList(featuresFloatArray[i]);
forward(forwardObjs);
}
} else {
for (int i = 0; i < position; i++) {
forwardObjs[0] = Float.valueOf(labels[i]);
forwardObjs[1] = Arrays.asList(featuresArray[i]);
forward(forwardObjs);
}
}
}
@Override
public void close() throws HiveException {
if (position > 0) {
flushBuffered(position);
}
// release resources to help GCs
this.labels = null;
this.featuresArray = null;
this.featuresFloatArray = null;
}
}