/*
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.varselect;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;

import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link VarSelectReducer} accumulates the per-column MSE values from all mappers.
 *
 * <p>
 * Because this is a global accumulation, the reducer number of the current MapReduce job should be set to 1.
 *
 * <p>
 * Input type is (ColumnId, Iterable(MSE)), collected from all mapper tasks.
 *
 * <p>
 * In {@link #cleanup(org.apache.hadoop.mapreduce.Reducer.Context)}, variables with their MSE values are sorted
 * according to the variable wrapper type. Based on the {@link #filterOutRatio} setting, only variables in that range
 * are written into HDFS.
 *
 * <p>
 * {@link #filterOutRatio} specifies the percentage of variables to remove in each iteration. A ratio works better
 * than a fixed number because the candidate set shrinks on every pass. Say there are 100 variables and wrapperRatio
 * is 0.05: the first pass removes 100 * 0.05 = 5 variables, while the second pass removes 95 * 0.05 (rounded down
 * to 4) variables.
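 *
 * <p>
 * As a minimal sketch (illustrative only; each job runs a single filtering pass, iteration happens across job runs),
 * the shrinking candidate count can be traced like this:
 *
 * <pre>
 * {@code
 * int remaining = 100;
 * float wrapperRatio = 0.05f;
 * for(int pass = 0; pass < 2; pass++) {
 *     remaining -= (int) (remaining * wrapperRatio); // pass 1: 100 -> 95, pass 2: 95 -> 91
 * }
 * }
 * </pre>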
 *
 * <p>
 * TODO Add mean value, not only MSE value; write mean and MSE to files for later analysis.
 *
 * @author Zhang David (pengzhang@paypal.com)
 */
public class VarSelectReducer extends Reducer<LongWritable, ColumnInfo, Text, Text> {

    private final static Logger LOG = LoggerFactory.getLogger(VarSelectReducer.class);

    /**
     * Final results list; this list is accumulated in memory and written out by context in cleanup.
     */
    private List<Pair> results = new ArrayList<Pair>();

    /**
     * Column config list read from HDFS.
     */
    private List<ColumnConfig> columnConfigList;

    /**
     * Basic input node count for the NN model: all the variables selected in current model training.
     */
    private int inputNodeCount;

    /**
     * Ratio of variables to filter out per iteration, set as a ratio instead of an absolute number. For example,
     * with 100 variables and a ratio of 0.05, the first pass keeps 95 variables; as the candidates decrease, the
     * ratio stays 0.05 but the next pass removes only 4 variables.
     */
    private float filterOutRatio;

    /**
     * Explicitly set number of variables to be selected; this overrides {@link #filterOutRatio}.
     */
    private int filterNum;

    /**
     * Prevent too many new objects for output key.
     */
    private Text outputKey;

    /**
     * Prevent too many new objects for output value.
     */
    private Text outputValue;

    /**
     * Output value text.
     */
    private final static Text OUTPUT_VALUE = new Text("");

    /**
     * Wrapper type: by sensitivity with target (ST) or by sensitivity (SE).
     */
    private String filterBy;

    /**
     * Multiple outputs to write the SE report to HDFS.
     */
    private MultipleOutputs<Text, Text> mos;

    /**
     * Load columnConfigList from the configured source type.
     */
    private void loadConfigFiles(final Context context) {
        try {
            SourceType sourceType = SourceType.valueOf(context.getConfiguration().get(
                    Constants.SHIFU_MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
            this.columnConfigList = CommonUtils.loadColumnConfigList(
                    context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG), sourceType);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Do initialization like ColumnConfig loading and reading filter settings.
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        loadConfigFiles(context);
        int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList);
        this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
        this.filterOutRatio = context.getConfiguration().getFloat(Constants.SHIFU_VARSELECT_FILTEROUT_RATIO,
                Constants.SHIFU_DEFAULT_VARSELECT_FILTEROUT_RATIO);
        this.filterNum = context.getConfiguration().getInt(Constants.SHIFU_VARSELECT_FILTER_NUM,
                Constants.SHIFU_DEFAULT_VARSELECT_FILTER_NUM);
        this.outputKey = new Text();
        this.outputValue = new Text();
        this.filterBy = context.getConfiguration()
                .get(Constants.SHIFU_VARSELECT_FILTEROUT_TYPE, Constants.FILTER_BY_SE);
        this.mos = new MultipleOutputs<Text, Text>(context);
        LOG.info("FilterBy is {}, filterOutRatio is {}, filterNum is {}", filterBy, filterOutRatio, filterNum);
    }

    @Override
    protected void reduce(LongWritable key, Iterable<ColumnInfo> values, Context context) throws IOException,
            InterruptedException {
        ColumnStatistics column = new ColumnStatistics();
        double sum = 0d;
        double sumSquare = 0d;
        long count = 0L;
        // merge the partial sums from all mapper tasks for this column
        for(ColumnInfo info: values) {
            sum += info.getSumScoreDiff();
            sumSquare += info.getSumSquareScoreDiff();
            count += info.getCount();
        }
        column.setMean(sum / count);
        column.setRms(Math.sqrt(sumSquare / count));
        // one-pass variance: E[x^2] - (E[x])^2
        column.setVariance((sumSquare / count) - power2(sum / count));
        this.results.add(new Pair(key.get(), column));
    }

    private double power2(double data) {
        return data * data;
    }

    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
        // sort descending by RMS so the most sensitive variables come first
        Collections.sort(this.results, new Comparator<Pair>() {
            @Override
            public int compare(Pair o1, Pair o2) {
                return Double.compare(o2.value.getRms(), o1.value.getRms());
            }
        });
        LOG.debug("Final Results:{}", this.results);
        int candidates = this.filterNum;
        if(candidates <= 0) {
            if(Constants.FILTER_BY_ST.equalsIgnoreCase(this.filterBy)
                    || Constants.FILTER_BY_SE.equalsIgnoreCase(this.filterBy)) {
                // SE/ST: keep (1 - filterOutRatio) of the current input variables
                candidates = (int) (this.inputNodeCount * (1.0f - this.filterOutRatio));
            } else {
                // wrapper by A: keep filterOutRatio of the current input variables
                candidates = (int) (this.inputNodeCount * (this.filterOutRatio));
            }
        }
        LOG.info("Candidates count is {}", candidates);
        for(int i = 0; i < this.results.size(); i++) {
            Pair pair = this.results.get(i);
            this.outputKey.set(pair.key + "");
            if(i < candidates) {
                context.write(this.outputKey, OUTPUT_VALUE);
            }
            // for thousands of features, creating a new StringBuilder per record here is ok
            StringBuilder sb = new StringBuilder(100);
            sb.append(this.columnConfigList.get((int) pair.key).getColumnName()).append("\t")
                    .append(pair.value.getMean()).append("\t").append(pair.value.getRms()).append("\t")
                    .append(pair.value.getVariance());
            this.outputValue.set(sb.toString());
            this.mos.write(Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME, this.outputKey, this.outputValue);
        }
        this.mos.close();
    }
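
    /*
     * A sketch of one SE report record as written above (tab-separated; the key is the column id, the value holds
     * column name, mean, RMS and variance; the names and numbers below are illustrative only):
     *
     *     12    derived_age    0.0031415    0.0456789    0.0019876
     */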
    private static class Pair {

        public Pair(long key, ColumnStatistics value) {
            this.key = key;
            this.value = value;
        }

        public long key;

        public ColumnStatistics value;

        @Override
        public String toString() {
            return key + ":" + value;
        }
    }
}
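
/*
 * Driver-side sketch (assumed, not part of this class): the settings this reducer expects are a single reduce task
 * for the global accumulation, plus a named output registered for the SE report. The job wiring below is
 * illustrative only.
 *
 *     Job job = Job.getInstance(new Configuration(), "shifu var select");
 *     job.setReducerClass(VarSelectReducer.class);
 *     job.setNumReduceTasks(1); // global accumulation requires exactly one reducer
 *     MultipleOutputs.addNamedOutput(job, Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME, TextOutputFormat.class,
 *             Text.class, Text.class);
 */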