/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.posttrain; import org.apache.hadoop.io.DoubleWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Mapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import ml.shifu.shifu.container.obj.ColumnConfig; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.container.obj.RawSourceData.SourceType; import ml.shifu.shifu.core.DataPurifier; import ml.shifu.shifu.util.CommonUtils; import ml.shifu.shifu.util.Constants; /** * {@link FeatureImportanceMapper} is to compute the most important variables in one model. * * <p> * Per each record, get the top 3 biggest variables in one bin. Then sent to reducer for further statistics. * * @author Zhang David (pengzhang@paypal.com) */ public class FeatureImportanceMapper extends Mapper<LongWritable, Text, IntWritable, DoubleWritable> { private final static Logger LOG = LoggerFactory.getLogger(FeatureImportanceMapper.class); /** * Model Config read from HDFS */ private ModelConfig modelConfig; /** * To filter records by customized expressions */ private DataPurifier dataPurifier; /** * Output key cache to avoid new operation. */ private IntWritable outputKey; /** * Tag column index */ private int tagColumnNum = -1; /** * Column Config list read from HDFS */ private List<ColumnConfig> columnConfigList; // cache tags in set for search private Set<String> tags; private String[] headers; /** * Prevent too many new objects for output key. */ private DoubleWritable outputValue; private Map<Integer, Double> variableStatsMap; private void loadConfigFiles(final Context context) { try { SourceType sourceType = SourceType.valueOf(context.getConfiguration().get( Constants.SHIFU_MODELSET_SOURCE_TYPE, SourceType.HDFS.toString())); this.modelConfig = CommonUtils.loadModelConfig( context.getConfiguration().get(Constants.SHIFU_MODEL_CONFIG), sourceType); this.columnConfigList = CommonUtils.loadColumnConfigList( context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG), sourceType); } catch (IOException e) { throw new RuntimeException(e); } } @Override protected void setup(Context context) throws IOException, InterruptedException { loadConfigFiles(context); loadTagWeightNum(); this.dataPurifier = new DataPurifier(this.modelConfig); this.outputKey = new IntWritable(); this.outputValue = new DoubleWritable(); this.tags = new HashSet<String>(modelConfig.getFlattenTags()); this.headers = CommonUtils.getFinalHeaders(modelConfig); this.initFeatureStats(); } /** * Load tag weight index field. */ private void loadTagWeightNum() { for(ColumnConfig config: this.columnConfigList) { if(config.isTarget()) { this.tagColumnNum = config.getColumnNum(); break; } } if(this.tagColumnNum == -1) { throw new RuntimeException("No valid target column."); } } private void initFeatureStats() { this.variableStatsMap = new HashMap<Integer, Double>(); for(ColumnConfig config: this.columnConfigList) { if(!config.isMeta() && !config.isTarget() && config.isFinalSelect()) { this.variableStatsMap.put(config.getColumnNum(), 0d); } } } public static class FeatureScore { public FeatureScore(int columnNum, int binAvgScore) { super(); this.columnNum = columnNum; this.binAvgScore = binAvgScore; } private int columnNum; private int binAvgScore; } @Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { String valueStr = value.toString(); // StringUtils.isBlank is not used here to avoid import new jar if(valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) { LOG.warn("Empty input."); return; } if(!this.dataPurifier.isFilterOut(valueStr)) { return; } String[] units = CommonUtils.split(valueStr, this.modelConfig.getDataSetDelimiter()); // tagColumnNum should be in units array, if not IndexOutofBoundException String tag = CommonUtils.trimTag(units[this.tagColumnNum]); if(!this.tags.contains(tag)) { if(System.currentTimeMillis() % 20 == 0) { LOG.warn("Data with invalid tag is ignored in post train, invalid tag: {}.", tag); } context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L); return; } List<FeatureScore> featureScores = new ArrayList<FeatureImportanceMapper.FeatureScore>(); for(int i = 0; i < headers.length; i++) { ColumnConfig config = this.columnConfigList.get(i); if(!config.isMeta() && !config.isTarget() && config.isFinalSelect()) { int binNum = CommonUtils.getBinNum(config, units[i]); List<Integer> binAvgScores = config.getBinAvgScore(); int binScore = 0; if(binNum == -1) { binScore = binAvgScores.get(binAvgScores.size() - 1); } else { binScore = binAvgScores.get(binNum); } featureScores.add(new FeatureScore(config.getColumnNum(), binScore)); } } Collections.sort(featureScores, new Comparator<FeatureScore>() { @Override public int compare(FeatureScore fs1, FeatureScore fs2) { if(fs1.binAvgScore < fs2.binAvgScore) { return 1; } if(fs1.binAvgScore > fs2.binAvgScore) { return -1; } return 0; } }); int size = featureScores.size() >= 3 ? 3 : featureScores.size(); for(int i = 0; i < size; i++) { FeatureScore featureScore = featureScores.get(i); Double currValue = this.variableStatsMap.get(featureScore.columnNum); currValue += size - i; this.variableStatsMap.put(featureScore.columnNum, currValue); } } @Override protected void cleanup(Context context) throws IOException, InterruptedException { for(Entry<Integer, Double> entry: this.variableStatsMap.entrySet()) { this.outputKey.set(entry.getKey()); this.outputValue.set(entry.getValue()); context.write(this.outputKey, this.outputValue); } } }