/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.varselect;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.dtrain.CommonConstants;
import ml.shifu.shifu.core.dtrain.DTrainUtils;
import ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.encog.ml.MLRegression;
import org.encog.ml.data.basic.BasicMLData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Splitter;
/**
* Mapper implementation to accumulate MSE value when remove one column.
*
* <p>
* All the MSE values are accumulated in in-memory HashMap {@link #results}, which will also be write out in
* {@link #cleanup(org.apache.hadoop.mapreduce.Mapper.Context)}.
*
* <p>
* Output of all the mappers will be read and accumulated in VarSelectReducer to get all global MSE values. In Reducer,
* all MSE values sorted and select valid variables.
*
* @author Zhang David (pengzhang@paypal.com)
*/
public class VarSelectMapper extends Mapper<LongWritable, Text, LongWritable, ColumnInfo> {
private final static Logger LOG = LoggerFactory.getLogger(VarSelectMapper.class);
/**
* Default splitter used to split input record. Use one instance to prevent more news in Splitter.on.
*/
private static final Splitter DEFAULT_SPLITTER = Splitter.on(CommonConstants.DEFAULT_COLUMN_SEPARATOR);
/**
* Model Config read from HDFS
*/
private ModelConfig modelConfig;
/**
* Column Config list read from HDFS
*/
private List<ColumnConfig> columnConfigList;
/**
* Basic neural network model instance to compute basic score with all selected columns and wrapper selected
* columns.
*/
private MLRegression model;
/**
* Basic input node count for NN model, all the variables selected in current model training.
*/
private int inputNodeCount;
/**
* Final results map, this map is loaded in memory for sum, and will be written by context in cleanup.
*/
private Map<Long, ColumnInfo> results = new HashMap<Long, ColumnInfo>();
/**
* Inputs columns for each record. To save new objects in
* {@link #map(LongWritable, Text, org.apache.hadoop.mapreduce.Mapper.Context)}.
*/
private double[] inputs;
/**
* Outputs columns for each record. To save new objects in
* {@link #map(LongWritable, Text, org.apache.hadoop.mapreduce.Mapper.Context)}.
*/
private double[] outputs;
/**
* Column indexes for each record. To save new objects in
* {@link #map(LongWritable, Text, org.apache.hadoop.mapreduce.Mapper.Context)}.
*/
private long[] columnIndexes;
/**
* Input MLData instance to save new.
*/
private BasicMLData inputsMLData;
/**
* Prevent too many new objects for output key.
*/
private LongWritable outputKey;
/**
* Filter by sensitivity by target(ST) or sensitivity(SE).
*/
private String filterBy;
/**
* A counter to count # of records in current mapper.
*/
private long recordCount;
private Set<Integer> featureSet;
/**
* Load all configurations for modelConfig and columnConfigList from source type.
*/
private void loadConfigFiles(final Context context) {
try {
SourceType sourceType = SourceType.valueOf(context.getConfiguration().get(
Constants.SHIFU_MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
this.modelConfig = CommonUtils.loadModelConfig(
context.getConfiguration().get(Constants.SHIFU_MODEL_CONFIG), sourceType);
this.columnConfigList = CommonUtils.loadColumnConfigList(
context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG), sourceType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* Load first model in model path as a {@link MLRegression} instance.
*/
private void loadModel() throws IOException {
this.model = (MLRegression) (CommonUtils.loadBasicModels(this.modelConfig, this.columnConfigList, null).get(0));
}
/**
* Do initialization like ModelConfig and ColumnConfig loading, model loading and others like input or output number
* loading.
*/
@Override
protected void setup(Context context) throws IOException, InterruptedException {
loadConfigFiles(context);
loadModel();
this.filterBy = context.getConfiguration()
.get(Constants.SHIFU_VARSELECT_FILTEROUT_TYPE, Constants.FILTER_BY_SE);
int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(this.columnConfigList);
this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
if(this.model instanceof BasicFloatNetwork) {
this.inputs = new double[((BasicFloatNetwork) this.model).getFeatureSet().size()];
this.featureSet = ((BasicFloatNetwork) this.model).getFeatureSet();
} else {
this.inputs = new double[this.inputNodeCount];
}
boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
// cache all feature list for sampling features
if(this.featureSet == null || this.featureSet.size() == 0) {
this.featureSet = new HashSet<Integer>(CommonUtils.getAllFeatureList(columnConfigList, isAfterVarSelect));
this.inputs = new double[this.featureSet.size()];
}
this.outputs = new double[inputOutputIndex[1]];
this.columnIndexes = new long[this.inputs.length];
this.inputsMLData = new BasicMLData(this.inputs.length);
this.outputKey = new LongWritable();
LOG.info("Filter by is {}", filterBy);
}
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
recordCount += 1L;
int index = 0, inputsIndex = 0, outputsIndex = 0;
for(String input: DEFAULT_SPLITTER.split(value.toString())) {
double doubleValue = NumberFormatUtils.getDouble(input.trim(), 0.0d);
if(index == this.columnConfigList.size()) {
break;
} else {
ColumnConfig columnConfig = this.columnConfigList.get(index);
if(columnConfig != null && columnConfig.isTarget()) {
this.outputs[outputsIndex++] = doubleValue;
} else {
if(this.featureSet.contains(columnConfig.getColumnNum())) {
inputs[inputsIndex] = doubleValue;
columnIndexes[inputsIndex++] = columnConfig.getColumnNum();
}
}
}
index++;
}
double oldValue = 0.0d;
this.inputsMLData.setData(this.inputs);
double candidateModelScore = 0d;
if(Constants.FILTER_BY_SE.equalsIgnoreCase(this.filterBy)) {
candidateModelScore = this.model.compute(new BasicMLData(inputs)).getData()[0];
}
for(int i = 0; i < this.inputs.length; i++) {
oldValue = this.inputs[i];
this.inputs[i] = 0d;
this.inputsMLData.setData(this.inputs);
double currentModelScore = this.model.compute(new BasicMLData(inputs)).getData()[0];
double diff = 0d;
if(Constants.FILTER_BY_ST.equalsIgnoreCase(this.filterBy)) {
// ST
diff = this.outputs[0] - currentModelScore;
} else {
// SE
diff = candidateModelScore - currentModelScore;
}
ColumnInfo columnInfo = this.results.get(this.columnIndexes[i]);
if(columnInfo == null) {
columnInfo = new ColumnInfo();
columnInfo.setSumScoreDiff(Math.abs(diff));
columnInfo.setSumSquareScoreDiff(power2(diff));
} else {
columnInfo.setSumScoreDiff(columnInfo.getSumScoreDiff() + Math.abs(diff));
columnInfo.setSumSquareScoreDiff(columnInfo.getSumSquareScoreDiff() + power2(diff));
}
this.results.put(this.columnIndexes[i], columnInfo);
this.inputs[i] = oldValue;
}
if(this.recordCount % 1000 == 0) {
LOG.info("Finish to process {} records.", this.recordCount);
}
}
/**
* Write all column->MSE pairs to output.
*/
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
for(Entry<Long, ColumnInfo> entry: results.entrySet()) {
this.outputKey.set(entry.getKey());
// value is sumValue, not sumValue/(number of records)
ColumnInfo columnInfo = entry.getValue();
columnInfo.setCount(this.recordCount);
context.write(this.outputKey, columnInfo);
}
LOG.debug("Final results: {}", results);
}
private double power2(double data) {
return data * data;
}
}