/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.udf; import java.io.IOException; import java.text.DecimalFormat; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import ml.shifu.shifu.column.NSColumn; import ml.shifu.shifu.container.WeightAmplifier; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType; import ml.shifu.shifu.core.DataSampler; import ml.shifu.shifu.core.Normalizer; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.Constants; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.jexl2.Expression; import org.apache.commons.jexl2.JexlContext; import org.apache.commons.jexl2.JexlEngine; import org.apache.commons.jexl2.MapContext; import org.apache.commons.lang.StringUtils; import org.apache.pig.data.Tuple; import org.apache.pig.data.TupleFactory; import org.apache.pig.impl.logicalLayer.schema.Schema; import org.apache.pig.impl.util.Utils; import org.apache.pig.tools.pigstats.PigStatusReporter; /** * NormalizeUDF class normalize the training data for parquet format. */ public class NormalizeUDF extends AbstractTrainerUDF<Tuple> { private List<Set<String>> tags; private Double cutoff; private NormType normType; private Expression weightExpr; private JexlContext weightContext; private DecimalFormat df = new DecimalFormat("#.######"); /** * For categorical feature, a map is used to save query time in execution */ private Map<Integer, Map<String, Integer>> categoricalIndexMap = new HashMap<Integer, Map<String, Integer>>(); public static enum WarnInNormalizeUDF { INVALID_TAG; }; // if current norm for only clean and not transform categorical and numeric value private boolean isForClean = false; public NormalizeUDF(String source, String pathModelConfig, String pathColumnConfig) throws Exception { this(source, pathModelConfig, pathColumnConfig, "false"); } public NormalizeUDF(String source, String pathModelConfig, String pathColumnConfig, String isForClean) throws Exception { super(source, pathModelConfig, pathColumnConfig); this.isForClean = "true".equalsIgnoreCase(isForClean); log.debug("Initializing NormalizeUDF ... "); cutoff = modelConfig.getNormalizeStdDevCutOff(); log.debug("\t stdDevCutOff: " + cutoff); normType = modelConfig.getNormalizeType(); log.debug("\t normType: " + normType.name()); weightExpr = createExpression(modelConfig.getWeightColumnName()); if(weightExpr != null) { weightContext = new MapContext(); } this.tags = super.modelConfig.getSetTags(); for(ColumnConfig config: columnConfigList) { if(config.isCategorical()) { Map<String, Integer> map = new HashMap<String, Integer>(); if(config.getBinCategory() != null) { for(int i = 0; i < config.getBinCategory().size(); i++) { List<String> catValues = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i)); for(String cval: catValues) { map.put(cval, i); } } } this.categoricalIndexMap.put(config.getColumnNum(), map); } } log.debug("NormalizeUDF Initialized"); } @SuppressWarnings("deprecation") public Tuple exec(Tuple input) throws IOException { if(input == null || input.size() == 0) { return null; } // update total valid count if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "TOTAL_VALID_COUNT")) { PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "TOTAL_VALID_COUNT").increment(1); } final String rawTag = CommonUtils.trimTag(input.get(tagColumnNum).toString()); // make sure all invalid tag record are filter out if(!super.tagSet.contains(rawTag)) { if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG")) { PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1); } return null; } // data sampling only for normalization, for data cleaning, shouldn't do data sampling if(!this.isForClean) { // do data sampling. Unselected data or data with invalid tag will be filtered out. boolean isNotSampled = DataSampler.isNotSampled(modelConfig.isRegression(), super.tagSet, super.posTagSet, super.negTagSet, modelConfig.getNormalizeSampleRate(), modelConfig.isNormalizeSampleNegOnly(), rawTag); if(isNotSampled) { return null; } } // append tuple with tag, normalized value. Tuple tuple = TupleFactory.getInstance().newTuple(); final NormType normType = modelConfig.getNormalizeType(); for(int i = 0; i < input.size(); i++) { ColumnConfig config = columnConfigList.get(i); String val = (input.get(i) == null) ? "" : input.get(i).toString().trim(); // load variables for weight calculating. if(weightExpr != null) { weightContext.set(new NSColumn(config.getColumnName()).getSimpleName(), val); } // check tag type. if(tagColumnNum == i) { if(modelConfig.isRegression()) { int type = 0; if(super.posTagSet.contains(rawTag)) { type = 1; } else if(super.negTagSet.contains(rawTag)) { type = 0; } else { log.error("Invalid data! The target value is not listed - " + rawTag); warn("Invalid data! The target value is not listed - " + rawTag, WarnInNormalizeUDF.INVALID_TAG); return null; } tuple.append(type); } else { int index = -1; for(int j = 0; j < tags.size(); j++) { Set<String> tagSet = tags.get(j); if(tagSet.contains(rawTag)) { index = j; break; } } if(index == -1) { log.error("Invalid data! The target value is not listed - " + rawTag); return null; } tuple.append(index); } continue; } if(this.isForClean) { // for RF/GBT model, only clean data, not real do norm data if(config.isCategorical()) { Map<String, Integer> map = this.categoricalIndexMap.get(config.getColumnNum()); // map should not be null, no need check if map is null, if val not in binCategory, set it to "" tuple.append(((map.get(val) == null || map.get(val) == -1)) ? "" : val); } else { Double normVal = 0d; try { normVal = Double.parseDouble(val); } catch (Exception e) { log.debug("Not decimal format " + val + ", using default!"); normVal = Normalizer.defaultMissingValue(config); } tuple.append(df.format(normVal)); } } else { // append normalize data. exclude data clean, for data cleaning, no need check good or bad candidate if(CommonUtils.isGoodCandidate(modelConfig.isRegression(), config)) { // for multiple classification, binPosRate means rate of such category over all counts, reuse // binPosRate for normalize Double normVal = Normalizer.normalize(config, val, cutoff, normType); tuple.append(df.format(normVal)); } else { tuple.append(config.isMeta() ? val : null); } } } // append tuple with weight. double weight = evaluateWeight(weightExpr, weightContext); tuple.append(weight); return tuple; } /** * Evaluate weight expression based on the variables context. * * @param expr * - weight evaluation expression * @param jc * - A JexlContext containing variables for weight expression. * @return The result of this evaluation */ public double evaluateWeight(Expression expr, JexlContext jc) { double weight = 1.0d; if(expr != null) { Object result = expr.evaluate(jc); if(result instanceof Integer) { weight = ((Integer) result).doubleValue(); } else if(result instanceof Double) { weight = ((Double) result).doubleValue(); } else if(result instanceof String) { try { weight = Double.parseDouble((String) result); } catch (NumberFormatException e) { // Not a number, use default if(System.currentTimeMillis() % 100 == 0) { log.warn("Weight column type is String and value cannot be parsed with " + result + ", use default 1.0d"); } weight = 1.0d; } } } return weight; } public Schema outputSchema(Schema input) { try { StringBuilder schemaStr = new StringBuilder(); schemaStr.append("Normalized:Tuple("); for(ColumnConfig config: columnConfigList) { if(config.isMeta()) { schemaStr.append(config.getColumnName() + ":chararray" + ","); } else if(!config.isMeta() && config.isNumerical()) { schemaStr.append(config.getColumnName() + ":float" + ","); } else if(config.isTarget()) { schemaStr.append(config.getColumnName() + ":int" + ","); } else { if(config.isCategorical() && this.isForClean) { // clean data for DT algorithms, only store index, short is ok while Pig only have int type schemaStr.append(config.getColumnName() + ":chararray" + ","); } else { // for others, set to float, no matter LR/NN categorical or filter out feature with null schemaStr.append(config.getColumnName() + ":float" + ","); } } } schemaStr.append("weight:float)"); return Utils.getSchemaFromString(schemaStr.toString()); } catch (Exception e) { log.error("error in outputSchema", e); return null; } } /* * Create expressions for multi weight settings */ protected Map<Expression, Double> createExpressionMap(List<WeightAmplifier> weightExprList) { Map<Expression, Double> ewMap = new HashMap<Expression, Double>(); if(CollectionUtils.isNotEmpty(weightExprList)) { JexlEngine jexl = new JexlEngine(); for(WeightAmplifier we: weightExprList) { ewMap.put(jexl.createExpression(we.getTargetExpression()), we.getTargetWeight()); } } return ewMap; } /* * Create the expression for weight setting */ private Expression createExpression(String weightAmplifier) { if(StringUtils.isNotBlank(weightAmplifier)) { JexlEngine jexl = new JexlEngine(); return jexl.createExpression(weightAmplifier); } return null; } }