/*
 * Copyright [2013-2017] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.udf;

import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;

import ml.shifu.shifu.container.obj.EvalConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;

import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.pig.data.DataType;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import org.apache.pig.impl.logicalLayer.schema.Schema;
import org.apache.pig.impl.logicalLayer.schema.Schema.FieldSchema;
import org.apache.pig.impl.util.UDFContext;
import org.apache.pig.tools.pigstats.PigStatusReporter;

/**
 * Projects only the columns needed in eval sorting: the target, score (meta) and weight columns are emitted.
 */
public class ColumnProjector extends AbstractTrainerUDF<Tuple> {

    private EvalConfig evalConfig;

    private String scoreMetaColumn;

    private String[] headers;

    private int targetColumnIndex = -1;

    private int weightColumnIndex = -1;

    private int scoreMetaColumnIndex = -1;

    // Double.MIN_VALUE is the smallest *positive* double, so it is a wrong sentinel for a running maximum over
    // scores that may be negative (e.g. GBDT regression); use the infinities instead.
    private double maxScore = Double.NEGATIVE_INFINITY;

    private double minScore = Double.POSITIVE_INFINITY;

    /**
     * Simple weight sanity check: if more than 5,000 weight values fail to parse, an exception is thrown.
     */
    private int weightExceptions;

    public ColumnProjector(String source, String pathModelConfig, String pathColumnConfig) throws IOException {
        super(source, pathModelConfig, pathColumnConfig);
    }
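
    /**
     * Full constructor used by the eval pipeline.
     *
     * @param source
     *            source type of the ModelConfig/ColumnConfig files
     * @param pathModelConfig
     *            path to ModelConfig.json
     * @param pathColumnConfig
     *            path to ColumnConfig.json
     * @param evalSetName
     *            name of the eval set defined in ModelConfig.json
     * @param columnName
     *            the score (meta) column to project and sort on
     * @throws IOException
     *             if config or header files cannot be read
     */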
    public ColumnProjector(String source, String pathModelConfig, String pathColumnConfig, String evalSetName,
            String columnName) throws IOException {
        super(source, pathModelConfig, pathColumnConfig);
        this.evalConfig = modelConfig.getEvalConfigByName(evalSetName);
        this.scoreMetaColumn = columnName;

        // resolve headers, either from the configured header path or from the first line of the data set
        if(StringUtils.isNotBlank(evalConfig.getDataSet().getHeaderPath())) {
            this.headers = CommonUtils.getHeaders(evalConfig.getDataSet().getHeaderPath(),
                    evalConfig.getDataSet().getHeaderDelimiter(), evalConfig.getDataSet().getSource());
        } else {
            String delimiter = StringUtils.isBlank(evalConfig.getDataSet().getHeaderDelimiter())
                    ? evalConfig.getDataSet().getDataDelimiter()
                    : evalConfig.getDataSet().getHeaderDelimiter();
            String[] fields = CommonUtils.takeFirstLine(evalConfig.getDataSet().getDataPath(), delimiter,
                    evalConfig.getDataSet().getSource());
            if(StringUtils.join(fields, "").contains(modelConfig.getTargetColumnName())) {
                // the first line looks like a header: use it as the schema
                this.headers = new String[fields.length];
                for(int i = 0; i < fields.length; i++) {
                    this.headers[i] = CommonUtils.getRelativePigHeaderColumnName(fields[i]);
                }
                log.warn("No header path is provided, will try to read the first line and detect the schema.");
                log.warn("Schema in ColumnConfig.json is named after the first line of the data set path.");
            } else {
                // no header detected: fall back to positional names 0, 1, 2, ...
                log.warn("No header path is provided, will try to read the first line and detect the schema.");
                log.warn("Schema in ColumnConfig.json is named by index: 0, 1, 2, 3 ...");
                log.warn("Please make sure the weight column and tag column also use the index as their names.");
                this.headers = new String[fields.length];
                for(int i = 0; i < fields.length; i++) {
                    this.headers[i] = i + "";
                }
            }
        }

        // locate target, score (meta) and weight columns in the resolved header
        for(int i = 0; i < this.headers.length; i++) {
            if(this.headers[i].equals(evalConfig.getDataSet().getTargetColumnName())) {
                this.targetColumnIndex = i;
            }
            if(this.headers[i].equals(this.scoreMetaColumn)) {
                this.scoreMetaColumnIndex = i;
            }
            if(StringUtils.isNotBlank(evalConfig.getDataSet().getWeightColumnName())
                    && this.headers[i].equals(evalConfig.getDataSet().getWeightColumnName())) {
                this.weightColumnIndex = i;
            }
        }
    }

    @SuppressWarnings("deprecation")
    @Override
    public Tuple exec(Tuple input) throws IOException {
        Tuple tuple = TupleFactory.getInstance().newTuple(3);
        String tag = input.get(targetColumnIndex).toString();
        tuple.set(0, tag);
        double score = 0;
        try {
            score = Double.parseDouble(input.get(scoreMetaColumnIndex).toString());
        } catch (Exception e) {
            // count unparseable score values; the score falls back to 0
            if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "BAD_META_SCORE")) {
                PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "BAD_META_SCORE")
                        .increment(1);
            }
        }
        if(score > maxScore) {
            maxScore = score;
        }
        if(score < minScore) {
            minScore = score;
        }
        tuple.set(1, score);
        String weight = "1";
        if(weightColumnIndex != -1) {
            weight = input.get(weightColumnIndex).toString();
        }
        tuple.set(2, weight);

        incrementTagCounters(tag, weight);

        return tuple;
    }
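
    /**
     * Updates the record counter and the positive/negative tag counters, both raw and weighted. Hadoop counters
     * only hold long values, so the double weight is scaled by {@link Constants#EVAL_COUNTER_WEIGHT_SCALE} before
     * being accumulated.
     */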
    @SuppressWarnings("deprecation")
    private void incrementTagCounters(String tag, String weight) {
        if(tag == null || weight == null) {
            log.warn("tag " + tag + " or weight " + weight + " is null, skip counter update");
            return;
        }
        double dWeight = 1.0;
        if(StringUtils.isNotBlank(weight)) {
            try {
                dWeight = Double.parseDouble(weight);
            } catch (Exception e) {
                if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, "weight_exceptions")) {
                    PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, "weight_exceptions")
                            .increment(1);
                }
                weightExceptions += 1;
                if(weightExceptions > 5000) {
                    throw new IllegalStateException(
                            "Please check weight column in eval, exceptional weight count is over 5000");
                }
            }
        }
        // counters are longs, so scale the double weight before accumulating
        long weightLong = (long) (dWeight * Constants.EVAL_COUNTER_WEIGHT_SCALE);

        if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_RECORDS)) {
            PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_RECORDS)
                    .increment(1);
        }
        if(posTagSet.contains(tag)) {
            if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_POSTAGS)) {
                PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_POSTAGS)
                        .increment(1);
            }
            if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_WPOSTAGS)) {
                PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_WPOSTAGS)
                        .increment(weightLong);
            }
        }
        if(negTagSet.contains(tag)) {
            if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_NEGTAGS)) {
                PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_NEGTAGS)
                        .increment(1);
            }
            if(isPigEnabled(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_WNEGTAGS)) {
                PigStatusReporter.getInstance().getCounter(Constants.SHIFU_GROUP_COUNTER, Constants.COUNTER_WNEGTAGS)
                        .increment(weightLong);
            }
        }
    }

    @Override
    public void finish() {
        if(modelConfig.isClassification()) {
            return;
        }

        // Only for regression: scores from models like GBDT are not bounded to [0, 1], so the max and min scores
        // observed by this task are written out here to set the score bounds for eval performance.
        BufferedWriter writer = null;
        Configuration jobConf = UDFContext.getUDFContext().getJobConf();
        String scoreOutput = jobConf.get(Constants.SHIFU_EVAL_MAXMIN_SCORE_OUTPUT);
        log.debug("shifu.eval.maxmin.score.output is " + scoreOutput + ", job id is "
                + jobConf.get("mapreduce.job.id") + ", task id is " + jobConf.get("mapreduce.task.id")
                + ", partition is " + jobConf.get("mapreduce.task.partition") + ", attempt id is "
                + jobConf.get("mapreduce.task.attempt.id"));
        try {
            FileSystem fileSystem = FileSystem.get(jobConf);
            fileSystem.mkdirs(new Path(scoreOutput));
            // one part file per task attempt to avoid write conflicts between parallel tasks
            String taskMaxMinScoreFile = scoreOutput + File.separator + "part-"
                    + jobConf.get("mapreduce.task.attempt.id");
            writer = ShifuFileUtils.getWriter(taskMaxMinScoreFile, SourceType.HDFS);
            writer.write(maxScore + "," + minScore);
        } catch (IOException e) {
            log.error("error in finish", e);
        } finally {
            if(writer != null) {
                try {
                    writer.close();
                } catch (IOException ignore) {
                }
            }
        }
    }

    @Override
    public Schema outputSchema(Schema input) {
        try {
            Schema tupleSchema = new Schema();
            tupleSchema.add(new FieldSchema("target", DataType.CHARARRAY));
            tupleSchema.add(new FieldSchema(scoreMetaColumn, DataType.DOUBLE));
            tupleSchema.add(new FieldSchema("weight", DataType.CHARARRAY));
            return new Schema(new Schema.FieldSchema("score", tupleSchema, DataType.TUPLE));
        } catch (IOException e) {
            log.error("Error in outputSchema", e);
            return null;
        }
    }
}
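
// A minimal usage sketch of how this UDF could be wired into a Pig script. The paths, eval set name and score
// column below are hypothetical placeholders, not taken from this repository; the actual Shifu eval scripts
// supply their own parameters:
//
//   DEFINE ColumnProjector ml.shifu.shifu.udf.ColumnProjector('HDFS', 'ModelConfig.json',
//       'ColumnConfig.json', 'EvalSet1', 'meta_score');
//   data      = LOAD '$pathEvalScore' USING PigStorage('|');
//   projected = FOREACH data GENERATE ColumnProjector(*);
//   STORE projected INTO '$pathSortOutput' USING PigStorage('|');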