/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.classifier.multiclass;
import static hivemall.HivemallConstants.BIGINT_TYPE_NAME;
import static hivemall.HivemallConstants.INT_TYPE_NAME;
import static hivemall.HivemallConstants.STRING_TYPE_NAME;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
import hivemall.LearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.Margin;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.Text;
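
/**
 * Base class for Hivemall's multiclass online classifier UDTFs. It keeps one
 * {@link PredictionModel} per class label; concrete subclasses implement
 * {@link #train(FeatureValue[], Object)} to update those models one training
 * example at a time. When the task finishes, {@link #close()} forwards every
 * touched weight as a (label, feature, weight[, covar]) row.
 *
 * <p>As a sketch of typical usage (the actual function name depends on which
 * concrete subclass is registered in Hive; a passive-aggressive learner is
 * assumed here for illustration):</p>
 *
 * <pre>{@code
 * SELECT
 *   train_multiclass_pa(features, label) AS (label, feature, weight)
 * FROM
 *   training_samples;
 * }</pre>
 */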
public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF {
private static final Log logger = LogFactory.getLog(MulticlassOnlineClassifierUDTF.class);
    private ListObjectInspector featureListOI;
    private boolean parseFeature; // true when features arrive as "feature:value" strings
    private PrimitiveObjectInspector labelInputOI;
    protected Map<Object, PredictionModel> label2model; // one prediction model per class label
    protected int count; // number of training examples processed
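
    /**
     * Validates the arguments (a feature list, an int/bigint/string label, and
     * optional constant options), optionally preloads a model from the
     * distributed cache, and builds the output struct inspector.
     */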
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 2) {
throw new UDFArgumentException(
getClass().getSimpleName()
+ " takes 2 arguments: List<Int|BigInt|Text> features, {Int|BitInt|Text} label [, constant text options]");
}
PrimitiveObjectInspector featureInputOI = processFeaturesOI(argOIs[0]);
this.labelInputOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]);
String labelTypeName = labelInputOI.getTypeName();
if (!STRING_TYPE_NAME.equals(labelTypeName) && !INT_TYPE_NAME.equals(labelTypeName)
&& !BIGINT_TYPE_NAME.equals(labelTypeName)) {
throw new UDFArgumentTypeException(0, "label must be a type [Int|BigInt|Text]: "
+ labelTypeName);
}
processOptions(argOIs);
PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector
: featureInputOI;
this.label2model = new HashMap<Object, PredictionModel>(64);
if (preloadedModelFile != null) {
loadPredictionModel(label2model, preloadedModelFile, labelInputOI, featureOutputOI);
}
this.count = 0;
return getReturnOI(labelInputOI, featureOutputOI);
}
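
    /** Initial capacity of each per-label prediction model. */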
@Override
protected int getInitialModelSize() {
return 8192;
}
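
    /**
     * Validates that the first argument is a list of int/bigint/string features
     * and records whether string features must be parsed as "feature:value".
     */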
protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg)
throws UDFArgumentException {
this.featureListOI = (ListObjectInspector) arg;
ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector();
String keyTypeName = featureRawOI.getTypeName();
if (!STRING_TYPE_NAME.equals(keyTypeName) && !INT_TYPE_NAME.equals(keyTypeName)
&& !BIGINT_TYPE_NAME.equals(keyTypeName)) {
            throw new UDFArgumentTypeException(0,
                "1st argument must be a List of elements of type [Int|BigInt|Text]: " + keyTypeName);
}
this.parseFeature = STRING_TYPE_NAME.equals(keyTypeName);
return HiveUtils.asPrimitiveObjectInspector(featureRawOI);
}
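
    /**
     * Builds the output struct inspector: (label, feature, weight), plus a
     * trailing covar column when {@link #useCovariance()} is true.
     */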
protected StructObjectInspector getReturnOI(ObjectInspector labelRawOI,
ObjectInspector featureRawOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldNames.add("label");
ObjectInspector labelOI = ObjectInspectorUtils.getStandardObjectInspector(labelRawOI);
fieldOIs.add(labelOI);
fieldNames.add("feature");
ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI);
fieldOIs.add(featureOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
if (useCovariance()) {
fieldNames.add("covar");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
}
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
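
    /**
     * Processes one input row: parses the feature vector, copies the label to a
     * standard object, and delegates to {@link #train(FeatureValue[], Object)}.
     */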
@Override
public void process(Object[] args) throws HiveException {
        List<?> features = featureListOI.getList(args[0]);
FeatureValue[] featureVector = parseFeatures(features);
if (featureVector == null) {
return;
}
Object label = ObjectInspectorUtils.copyToStandardObject(args[1], labelInputOI);
if (label == null) {
throw new UDFArgumentException("label value must not be NULL");
}
count++;
train(featureVector, label);
}
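
    /**
     * Converts the raw feature list into a {@link FeatureValue} array. String
     * features are parsed as "feature[:value]" pairs; other types get an
     * implicit value of 1. Returns null for an empty list, and preserves null
     * elements as null entries.
     */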
@Nullable
protected final FeatureValue[] parseFeatures(@Nonnull final List<?> features) {
final int size = features.size();
if (size == 0) {
return null;
}
final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
final FeatureValue[] featureVector = new FeatureValue[size];
for (int i = 0; i < size; i++) {
Object f = features.get(i);
if (f == null) {
continue;
}
final FeatureValue fv;
if (parseFeature) {
fv = FeatureValue.parse(f);
} else {
Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
fv = new FeatureValue(k, 1.f);
}
featureVector[i] = fv;
}
return featureVector;
}
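
    /** Updates the per-label models with one labeled training example. */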
protected abstract void train(@Nonnull final FeatureValue[] features,
@Nonnull final Object actual_label);
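
    /**
     * Scores the features against every per-label model and returns the label
     * with the highest score.
     */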
protected final PredictionResult classify(@Nonnull final FeatureValue[] features) {
        float maxScore = Float.NEGATIVE_INFINITY; // note: Float.MIN_VALUE is the smallest positive float, not the minimum
Object maxScoredLabel = null;
        for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {// for each class
            Object label = entry.getKey();
            PredictionModel model = entry.getValue();
float score = calcScore(model, features);
if (maxScoredLabel == null || score > maxScore) {
maxScore = score;
maxScoredLabel = label;
}
}
return new PredictionResult(maxScoredLabel, maxScore);
}
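
    /**
     * Computes the multiclass margin inputs: the correct label's score and the
     * best-scoring other label together with its score.
     */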
protected Margin getMargin(@Nonnull final FeatureValue[] features, final Object actual_label) {
float correctScore = 0.f;
Object maxAnotherLabel = null;
float maxAnotherScore = 0.f;
        for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {// for each class
            Object label = entry.getKey();
            PredictionModel model = entry.getValue();
float score = calcScore(model, features);
if (label.equals(actual_label)) {
correctScore = score;
} else {
if (maxAnotherLabel == null || score > maxAnotherScore) {
maxAnotherLabel = label;
maxAnotherScore = score;
}
}
}
return new Margin(correctScore, maxAnotherLabel, maxAnotherScore);
}
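
    /**
     * Like {@link #getMargin(FeatureValue[], Object)} but also carries the
     * variance (correct-label variance plus best-other-label variance).
     */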
protected Margin getMarginAndVariance(@Nonnull final FeatureValue[] features,
final Object actual_label) {
return getMarginAndVariance(features, actual_label, false);
}
protected Margin getMarginAndVariance(@Nonnull final FeatureValue[] features,
final Object actual_label, boolean nonZeroVariance) {
float correctScore = 0.f;
float correctVariance = 0.f;
Object maxAnotherLabel = null;
float maxAnotherScore = 0.f;
float maxAnotherVariance = 0.f;
if (nonZeroVariance && label2model.isEmpty()) {// for initial call
float var = 2.f * calcVariance(features);
return new Margin(correctScore, maxAnotherLabel, maxAnotherScore).variance(var);
}
        for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {// for each class
            Object label = entry.getKey();
            PredictionModel model = entry.getValue();
PredictionResult predicted = calcScoreAndVariance(model, features);
float score = predicted.getScore();
if (label.equals(actual_label)) {
correctScore = score;
correctVariance = predicted.getVariance();
} else {
if (maxAnotherLabel == null || score > maxAnotherScore) {
maxAnotherLabel = label;
maxAnotherScore = score;
maxAnotherVariance = predicted.getVariance();
}
}
}
float var = correctVariance + maxAnotherVariance;
return new Margin(correctScore, maxAnotherLabel, maxAnotherScore).variance(var);
}
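
    /** Returns the squared L2 norm of the feature values: sum over x[i]^2. */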
protected final float squaredNorm(@Nonnull final FeatureValue[] features) {
float squared_norm = 0.f;
        for (FeatureValue f : features) {// norm += x[i] * x[i]
if (f == null) {
continue;
}
final float v = f.getValueAsFloat();
squared_norm += (v * v);
}
return squared_norm;
}
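
    /** Returns the model's score for the features: the inner product of w and x. */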
protected final float calcScore(@Nonnull final PredictionModel model,
@Nonnull final FeatureValue[] features) {
float score = 0.f;
for (FeatureValue f : features) {// a += w[i] * x[i]
if (f == null) {
continue;
}
final Object k = f.getFeature();
final float v = f.getValueAsFloat();
float old_w = model.getWeight(k);
if (old_w != 0f) {
score += (old_w * v);
}
}
return score;
}
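
    /** Returns sum over x[i]^2, used as the variance when no model exists yet. */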
protected final float calcVariance(@Nonnull final FeatureValue[] features) {
float variance = 0.f;
        for (FeatureValue f : features) {// variance += x[i] * x[i]
if (f == null) {
continue;
}
float v = f.getValueAsFloat();
variance += v * v;
}
return variance;
}
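
    /**
     * Returns the score (inner product of w and x) together with the variance
     * sum over cov[i] * x[i]^2; features unseen by the model contribute a
     * covariance of 1.
     */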
protected final PredictionResult calcScoreAndVariance(@Nonnull final PredictionModel model,
@Nonnull final FeatureValue[] features) {
float score = 0.f;
float variance = 0.f;
        for (FeatureValue f : features) {// score += w[i] * x[i], variance += cov[i] * x[i]^2
if (f == null) {
continue;
}
final Object k = f.getFeature();
final float v = f.getValueAsFloat();
IWeightValue old_w = model.get(k);
if (old_w == null) {
variance += (1.f * v * v);
} else {
score += (old_w.get() * v);
variance += (old_w.getCovariance() * v * v);
}
}
return new PredictionResult(score).variance(variance);
}
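
    /**
     * Additive update: for each feature f, w_actual[f] += coeff * x[f] and,
     * when a missed label is given, w_missed[f] -= coeff * x[f]. Models for
     * unseen labels are created on demand.
     */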
protected void update(@Nonnull final FeatureValue[] features, float coeff, Object actual_label,
Object missed_label) {
assert (actual_label != null);
if (actual_label.equals(missed_label)) {
throw new IllegalArgumentException("Actual label equals to missed label: "
+ actual_label);
}
PredictionModel model2add = label2model.get(actual_label);
if (model2add == null) {
model2add = createModel();
label2model.put(actual_label, model2add);
}
PredictionModel model2sub = null;
if (missed_label != null) {
model2sub = label2model.get(missed_label);
if (model2sub == null) {
model2sub = createModel();
label2model.put(missed_label, model2sub);
}
}
for (FeatureValue f : features) {// w[f] += y * x[f]
if (f == null) {
continue;
}
final Object k = f.getFeature();
final float v = f.getValueAsFloat();
float old_trueclass_w = model2add.getWeight(k);
float add_w = old_trueclass_w + (coeff * v);
model2add.set(k, new WeightValue(add_w));
if (model2sub != null) {
float old_falseclass_w = model2sub.getWeight(k);
float sub_w = old_falseclass_w - (coeff * v);
model2sub.set(k, new WeightValue(sub_w));
}
}
}
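
    /**
     * Emits the trained models, one (label, feature, weight[, covar]) row per
     * touched weight, then releases the models and logs statistics.
     */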
@Override
public final void close() throws HiveException {
super.close();
if (label2model != null) {
long numForwarded = 0L;
long numMixed = 0L;
if (useCovariance()) {
final WeightValueWithCovar probe = new WeightValueWithCovar();
final Object[] forwardMapObj = new Object[4];
final FloatWritable fv = new FloatWritable();
final FloatWritable cov = new FloatWritable();
for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
Object label = entry.getKey();
forwardMapObj[0] = label;
PredictionModel model = entry.getValue();
numMixed += model.getNumMixed();
IMapIterator<Object, IWeightValue> itor = model.entries();
while (itor.next() != -1) {
itor.getValue(probe);
if (!probe.isTouched()) {
continue; // skip outputting untouched weights
}
Object k = itor.getKey();
fv.set(probe.get());
cov.set(probe.getCovariance());
forwardMapObj[1] = k;
forwardMapObj[2] = fv;
forwardMapObj[3] = cov;
forward(forwardMapObj);
numForwarded++;
}
}
} else {
final WeightValue probe = new WeightValue();
final Object[] forwardMapObj = new Object[3];
final FloatWritable fv = new FloatWritable();
for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
Object label = entry.getKey();
forwardMapObj[0] = label;
PredictionModel model = entry.getValue();
numMixed += model.getNumMixed();
IMapIterator<Object, IWeightValue> itor = model.entries();
while (itor.next() != -1) {
itor.getValue(probe);
if (!probe.isTouched()) {
continue; // skip outputting untouched weights
}
Object k = itor.getKey();
fv.set(probe.get());
forwardMapObj[1] = k;
forwardMapObj[2] = fv;
forward(forwardMapObj);
numForwarded++;
}
}
}
this.label2model = null;
logger.info("Trained a prediction model using " + count + " training examples"
+ (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
logger.info("Forwarded the prediction model of " + numForwarded + " rows");
}
}
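
    /**
     * Loads a pre-trained model file (typically local via the distributed
     * cache) into the per-label model map and logs per-label feature counts.
     */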
protected void loadPredictionModel(Map<Object, PredictionModel> label2model, String filename,
PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI) {
final StopWatch elapsed = new StopWatch();
final long lines;
try {
if (useCovariance()) {
lines = loadPredictionModel(label2model, new File(filename), labelOI, featureOI,
writableFloatObjectInspector, writableFloatObjectInspector);
} else {
lines = loadPredictionModel(label2model, new File(filename), labelOI, featureOI,
writableFloatObjectInspector);
}
} catch (IOException e) {
throw new RuntimeException("Failed to load a model: " + filename, e);
} catch (SerDeException e) {
throw new RuntimeException("Failed to load a model: " + filename, e);
}
if (!label2model.isEmpty()) {
long totalFeatures = 0L;
StringBuilder statsBuf = new StringBuilder(256);
for (Map.Entry<Object, PredictionModel> e : label2model.entrySet()) {
Object label = e.getKey();
int numFeatures = e.getValue().size();
statsBuf.append('\n')
.append("Label: ")
.append(label)
.append(", Number of Features: ")
.append(numFeatures);
totalFeatures += numFeatures;
}
logger.info("Loaded total " + totalFeatures + " features from distributed cache '"
+ filename + "' (" + lines + " lines) in " + elapsed + statsBuf);
}
}
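
    /** Loads (label, feature, weight) text rows from a file, or recursively from a directory, skipping .crc files. */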
private long loadPredictionModel(Map<Object, PredictionModel> label2model, File file,
PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI,
WritableFloatObjectInspector weightOI) throws IOException, SerDeException {
long count = 0L;
if (!file.exists()) {
return count;
}
if (!file.getName().endsWith(".crc")) {
if (file.isDirectory()) {
                final File[] files = file.listFiles();
                if (files != null) {// listFiles() may return null on an I/O error
                    for (File f : files) {
                        count += loadPredictionModel(label2model, f, labelOI, featureOI, weightOI);
                    }
                }
} else {
LazySimpleSerDe serde = HiveUtils.getLineSerde(labelOI, featureOI, weightOI);
StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
StructField c1ref = lineOI.getStructFieldRef("c1");
StructField c2ref = lineOI.getStructFieldRef("c2");
StructField c3ref = lineOI.getStructFieldRef("c3");
PrimitiveObjectInspector c1refOI = (PrimitiveObjectInspector) c1ref.getFieldObjectInspector();
PrimitiveObjectInspector c2refOI = (PrimitiveObjectInspector) c2ref.getFieldObjectInspector();
FloatObjectInspector c3refOI = (FloatObjectInspector) c3ref.getFieldObjectInspector();
BufferedReader reader = null;
try {
reader = HadoopUtils.getBufferedReader(file);
String line;
while ((line = reader.readLine()) != null) {
count++;
Text lineText = new Text(line);
Object lineObj = serde.deserialize(lineText);
List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
Object f0 = fields.get(0);
Object f1 = fields.get(1);
Object f2 = fields.get(2);
if (f0 == null || f1 == null || f2 == null) {
                            continue; // skip malformed lines where any column is null
}
Object label = c1refOI.getPrimitiveWritableObject(c1refOI.copyObject(f0));
PredictionModel model = label2model.get(label);
if (model == null) {
model = createModel();
label2model.put(label, model);
}
Object k = c2refOI.getPrimitiveWritableObject(c2refOI.copyObject(f1));
float v = c3refOI.get(f2);
model.set(k, new WeightValue(v, false));
}
} finally {
IOUtils.closeQuietly(reader);
}
}
}
return count;
}
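
    /** Loads (label, feature, weight, covar) text rows from a file, or recursively from a directory, skipping .crc files. */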
private long loadPredictionModel(Map<Object, PredictionModel> label2model, File file,
PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI,
WritableFloatObjectInspector weightOI, WritableFloatObjectInspector covarOI)
throws IOException, SerDeException {
long count = 0L;
if (!file.exists()) {
return count;
}
if (!file.getName().endsWith(".crc")) {
if (file.isDirectory()) {
                final File[] files = file.listFiles();
                if (files != null) {// listFiles() may return null on an I/O error
                    for (File f : files) {
                        count += loadPredictionModel(label2model, f, labelOI, featureOI, weightOI,
                            covarOI);
                    }
                }
} else {
LazySimpleSerDe serde = HiveUtils.getLineSerde(labelOI, featureOI, weightOI,
covarOI);
StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
StructField c1ref = lineOI.getStructFieldRef("c1");
StructField c2ref = lineOI.getStructFieldRef("c2");
StructField c3ref = lineOI.getStructFieldRef("c3");
StructField c4ref = lineOI.getStructFieldRef("c4");
PrimitiveObjectInspector c1refOI = (PrimitiveObjectInspector) c1ref.getFieldObjectInspector();
PrimitiveObjectInspector c2refOI = (PrimitiveObjectInspector) c2ref.getFieldObjectInspector();
FloatObjectInspector c3refOI = (FloatObjectInspector) c3ref.getFieldObjectInspector();
FloatObjectInspector c4refOI = (FloatObjectInspector) c4ref.getFieldObjectInspector();
BufferedReader reader = null;
try {
reader = HadoopUtils.getBufferedReader(file);
String line;
while ((line = reader.readLine()) != null) {
count++;
Text lineText = new Text(line);
Object lineObj = serde.deserialize(lineText);
List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
Object f0 = fields.get(0);
Object f1 = fields.get(1);
Object f2 = fields.get(2);
Object f3 = fields.get(3);
if (f0 == null || f1 == null || f2 == null) {
                            continue; // skip lines where label, feature, or weight is null (a missing covariance gets the default below)
}
Object label = c1refOI.getPrimitiveWritableObject(c1refOI.copyObject(f0));
PredictionModel model = label2model.get(label);
if (model == null) {
model = createModel();
label2model.put(label, model);
}
Object k = c2refOI.getPrimitiveWritableObject(c2refOI.copyObject(f1));
float v = c3refOI.get(f2);
float cov = (f3 == null) ? WeightValueWithCovar.DEFAULT_COVAR
: c4refOI.get(f3);
model.set(k, new WeightValueWithCovar(v, cov, false));
}
} finally {
IOUtils.closeQuietly(reader);
}
}
}
return count;
}
}