/* * Copyright [2012-2014] PayPal Software Foundation * <p/> * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * <p/> * http://www.apache.org/licenses/LICENSE-2.0 * <p/> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.udf; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelStatsConf; import ml.shifu.shifu.container.obj.ModelStatsConf.BinningMethod; import ml.shifu.shifu.exception.ShifuErrorCode; import ml.shifu.shifu.exception.ShifuException; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.Constants; import org.apache.pig.data.*; import org.apache.pig.impl.logicalLayer.schema.Schema; import org.apache.pig.impl.logicalLayer.schema.Schema.FieldSchema; import org.apache.pig.tools.pigstats.PigStatusReporter; import java.io.IOException; /** * <pre> * AddColumnNumUDF class is to convert tuple of row data into bag of column data * Its structure is like * { * (column-id, column-value, column-tag, column-score) * (column-id, column-value, column-tag, column-score) * ... * } * </pre> */ public class AddColumnNumAndFilterUDF extends AddColumnNumUDF { private final boolean isAppendRandom; public AddColumnNumAndFilterUDF(String source, String pathModelConfig, String pathColumnConfig, String withScoreStr) throws Exception { this(source, pathModelConfig, pathColumnConfig, withScoreStr, "true"); } public AddColumnNumAndFilterUDF(String source, String pathModelConfig, String pathColumnConfig, String withScoreStr, String isAppendRandom) throws Exception { super(source, pathModelConfig, pathColumnConfig, withScoreStr); this.isAppendRandom = Boolean.TRUE.toString().equalsIgnoreCase(isAppendRandom); } @SuppressWarnings("deprecation") @Override public DataBag exec(Tuple input) throws IOException { DataBag bag = BagFactory.getInstance().newDefaultBag(); TupleFactory tupleFactory = TupleFactory.getInstance(); if(input == null) { return null; } int size = input.size(); if(size == 0 || input.size() != this.columnConfigList.size()) { log.info("the input size - " + input.size() + ", while column size - " + columnConfigList.size()); throw new ShifuException(ShifuErrorCode.ERROR_NO_EQUAL_COLCONFIG); } if(input.get(tagColumnNum) == null) { log.info("tagColumnNum is " + tagColumnNum + "; input size is " + input.size() + "; columnConfigList.size() is " + columnConfigList.size() + "; tuple is" + input.toDelimitedString("|") + "; tag is " + input.get(tagColumnNum)); if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG")) { PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1); } return null; } String tag = CommonUtils.trimTag(input.get(tagColumnNum).toString()); // filter out tag not in setting tagging list if(!super.tagSet.contains(tag)) { if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG")) { PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1); } return null; } Double rate = modelConfig.getBinningSampleRate(); if(modelConfig.isBinningSampleNegOnly()) { if(super.negTagSet.contains(tag) && random.nextDouble() > rate) { return null; } } else { if(random.nextDouble() > rate) { return null; } } for(int i = 0; i < size; i++) { ColumnConfig config = columnConfigList.get(i); // if (config.isCandidate()) { // all columns can be stats boolean isPositive = false; if(modelConfig.isRegression()) { if(super.posTagSet.contains(tag)) { isPositive = true; } else if(super.negTagSet.contains(tag)) { isPositive = false; } else { // not valid tag, just skip current record continue; } } if(!isValidRecord(modelConfig.isRegression(), isPositive, config)) { continue; } Tuple tuple = tupleFactory.newTuple(TOTAL_COLUMN_CNT); tuple.set(COLUMN_ID_INDX, i); // Set Data tuple.set(COLUMN_VAL_INDX, (input.get(i) == null ? null : input.get(i).toString())); if(modelConfig.isRegression()) { // Set Tag if(super.posTagSet.contains(tag)) { tuple.set(COLUMN_TAG_INDX, true); } if(super.negTagSet.contains(tag)) { tuple.set(COLUMN_TAG_INDX, false); } } else { // a mock for multiple classification tuple.set(COLUMN_TAG_INDX, true); } // get weight value tuple.set(COLUMN_WEIGHT_INDX, getWeightColumnVal(input)); // add random seed for distribution for bigger mapper, 300 is not enough TODO if(this.isAppendRandom) { tuple.set(COLUMN_SEED_INDX, Math.abs(random.nextInt() % 300)); } bag.add(tuple); } // } return bag; } @Override public Schema outputSchema(Schema input) { try { Schema tupleSchema = new Schema(); tupleSchema.add(new FieldSchema("columnId", DataType.INTEGER)); tupleSchema.add(new FieldSchema("value", DataType.CHARARRAY)); tupleSchema.add(new FieldSchema("tag", DataType.BOOLEAN)); if(this.isAppendRandom) { tupleSchema.add(new FieldSchema("rand", DataType.INTEGER)); } tupleSchema.add(new FieldSchema("weight", DataType.DOUBLE)); return new Schema(new Schema.FieldSchema("columnInfos", new Schema(new Schema.FieldSchema("columnInfo", tupleSchema, DataType.TUPLE)), DataType.BAG)); } catch (IOException e) { log.error("Error in outputSchema", e); return null; } } private boolean isValidRecord(boolean isBinary, boolean isPositive, ColumnConfig columnConfig) { if(isBinary) { return columnConfig != null && (columnConfig.isCategorical() || isValidBinningMethodForBinary(isPositive)); } else { return columnConfig != null && (columnConfig.isCategorical() || isValidBinningMethod()); } } private boolean isValidBinningMethodForBinary(boolean isPositive) { return modelConfig.getBinningAlgorithm().equals(ModelStatsConf.BinningAlgorithm.DynamicBinning) || modelConfig.getBinningMethod().equals(BinningMethod.EqualTotal) || modelConfig.getBinningMethod().equals(BinningMethod.EqualInterval) || (modelConfig.getBinningMethod().equals(BinningMethod.EqualPositive) && isPositive) || (modelConfig.getBinningMethod().equals(BinningMethod.EqualNegtive) && !isPositive) || modelConfig.getBinningMethod().equals(BinningMethod.WeightEqualTotal) || modelConfig.getBinningMethod().equals(BinningMethod.WeightEqualInterval) || (modelConfig.getBinningMethod().equals(BinningMethod.WeightEqualPositive) && isPositive) || (modelConfig.getBinningMethod().equals(BinningMethod.WeightEqualNegative) && !isPositive); } private boolean isValidBinningMethod() { return modelConfig.getBinningMethod().equals(BinningMethod.EqualTotal) || modelConfig.getBinningMethod().equals(BinningMethod.EqualInterval) || modelConfig.getBinningMethod().equals(BinningMethod.WeightEqualTotal) || modelConfig.getBinningMethod().equals(BinningMethod.WeightEqualInterval); } }