/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.udf;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.shifu.shifu.container.WeightAmplifier;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType;
import ml.shifu.shifu.core.DataSampler;
import ml.shifu.shifu.core.Normalizer;
import ml.shifu.shifu.util.CommonUtils;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.jexl2.Expression;
import org.apache.commons.jexl2.JexlContext;
import org.apache.commons.jexl2.JexlEngine;
import org.apache.commons.jexl2.MapContext;
import org.apache.commons.lang.StringUtils;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.schema.Schema;
import org.apache.pig.impl.util.Utils;
/**
* For parquet format, only double type data will be saved. Not string like in {@link NormalizeUDF}. TODO, should merge
* together with {@link NormalizeUDF}.
*/
public class NormalizeParquetUDF extends AbstractTrainerUDF<Tuple> {
private List<String> negTags;
private List<String> posTags;
private Double cutoff;
private NormType normType;
private Expression weightExpr;
private JexlContext weightContext;
private DecimalFormat df = new DecimalFormat("#.######");
private String alg;
// private DecimalFormat df = new DecimalFormat("#.######");
public NormalizeParquetUDF(String source, String pathModelConfig, String pathColumnConfig) throws Exception {
super(source, pathModelConfig, pathColumnConfig);
log.debug("Initializing NormalizeUDF ... ");
negTags = modelConfig.getNegTags();
log.debug("\t Negative Tags: " + negTags);
posTags = modelConfig.getPosTags();
log.debug("\t Positive Tags: " + posTags);
cutoff = modelConfig.getNormalizeStdDevCutOff();
log.debug("\t stdDevCutOff: " + cutoff);
normType = modelConfig.getNormalizeType();
log.debug("\t normType: " + normType.name());
weightExpr = createExpression(modelConfig.getWeightColumnName());
if(weightExpr != null) {
weightContext = new MapContext();
}
log.debug("NormalizeUDF Initialized");
this.alg = this.modelConfig.getAlgorithm();
}
public Tuple exec(Tuple input) throws IOException {
if(input == null || input.size() == 0) {
return null;
}
// do data sampling. Unselected data or data with invalid tag will be filtered out.
final String rawTag = CommonUtils.trimTag(input.get(tagColumnNum).toString());
boolean isNotSampled = DataSampler.isNotSampled(posTags, negTags, modelConfig.getNormalizeSampleRate(),
modelConfig.isNormalizeSampleNegOnly(), rawTag);
if(isNotSampled) {
return null;
}
// append tuple with tag, normalized value.
Tuple tuple = TupleFactory.getInstance().newTuple();
final NormType normType = modelConfig.getNormalizeType();
for(int i = 0; i < input.size(); i++) {
ColumnConfig config = columnConfigList.get(i);
String val = (input.get(i) == null) ? "" : input.get(i).toString();
// load variables for weight calculating.
if(weightExpr != null) {
weightContext.set(config.getColumnName(), val);
}
// check tag type.
if(tagColumnNum == i) {
String tagType = tagTypeCheck(posTags, negTags, rawTag);
if(tagType == null) {
log.error("Invalid data! The target value is not listed - " + rawTag);
return null;
}
tuple.append(Integer.parseInt(tagType));
continue;
}
// append normalize data.
if(!CommonUtils.isGoodCandidate(config)) {
tuple.append((Double) null);
} else {
if(CommonUtils.isTreeModel(this.alg)) {
Double normVal = 0d;
if(config.isCategorical()) {
tuple.append(val);
} else {
try {
normVal = Double.parseDouble(val);
} catch (Exception e) {
log.debug("Not decimal format " + val + ", using default!");
normVal = Normalizer.defaultMissingValue(config);
}
}
tuple.append(df.format(normVal));
} else {
Double normVal = Normalizer.normalize(config, val, cutoff, normType);
tuple.append(df.format(normVal));
}
}
}
// append tuple with weight.
double weight = evaluateWeight(weightExpr, weightContext);
tuple.append(weight);
return tuple;
}
/*
* Evaluate weight expression based on the variables context.
*
* @param expr
* - weight evaluation expression
* @param jc
* - A JexlContext containing variables for weight expression.
* @return The result of this evaluation
*/
public double evaluateWeight(Expression expr, JexlContext jc) {
double weight = 1.0d;
if(expr != null) {
Object result = expr.evaluate(jc);
if(result instanceof Integer) {
weight = ((Integer) result).doubleValue();
} else if(result instanceof Double) {
weight = ((Double) result).doubleValue();
} else if(result instanceof String) {
try {
weight = Double.parseDouble((String) result);
} catch (NumberFormatException e) {
// Not a number, use default
if(System.currentTimeMillis() % 100 == 0) {
log.warn("Weight column type is String and value cannot be parsed with " + result
+ ", use default 1.0d");
}
weight = 1.0d;
}
}
}
return weight;
}
/*
* Check tag type.
*
* @param posTags
* - positive tag list.
* @param negTags
* - negtive tag list.
* @param rawTag
* - raw tag string
* @return tag type String. Return "1" for positive tag. Return "0" for negtive tag. Return null for invalid tag.
*/
public String tagTypeCheck(List<String> posTags, List<String> negTags, String rawTag) {
String type = null;
if(posTags.contains(rawTag)) {
type = "1";
} else if(negTags.contains(rawTag)) {
type = "0";
}
return type;
}
public Schema outputSchema(Schema input) {
try {
StringBuilder schemaStr = new StringBuilder();
schemaStr.append("Normalized:Tuple(");
for(int i = 0; i < columnConfigList.size(); i++) {
ColumnConfig config = this.columnConfigList.get(i);
if(tagColumnNum == i) {
schemaStr.append(config.getColumnName() + ":float" + ",");
} else {
if(config.isCategorical() && CommonUtils.isTreeModel(this.alg)) {
schemaStr.append(config.getColumnName() + ":chararray" + ",");
} else {
schemaStr.append(config.getColumnName() + ":float" + ",");
}
}
}
schemaStr.append("weight:float)");
return Utils.getSchemaFromString(schemaStr.toString());
} catch (Exception e) {
log.error("error in outputSchema", e);
return null;
}
}
/*
* Create expressions for multi weight settings
*/
protected Map<Expression, Double> createExpressionMap(List<WeightAmplifier> weightExprList) {
Map<Expression, Double> ewMap = new HashMap<Expression, Double>();
if(CollectionUtils.isNotEmpty(weightExprList)) {
JexlEngine jexl = new JexlEngine();
for(WeightAmplifier we: weightExprList) {
ewMap.put(jexl.createExpression(we.getTargetExpression()), we.getTargetWeight());
}
}
return ewMap;
}
/*
* Create the expression for weight setting
*/
private Expression createExpression(String weightAmplifier) {
if(StringUtils.isNotBlank(weightAmplifier)) {
JexlEngine jexl = new JexlEngine();
return jexl.createExpression(weightAmplifier);
}
return null;
}
}