/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.posttrain;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.TaskInputOutputContext;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs;
import org.encog.ml.BasicML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import ml.shifu.shifu.container.CaseScoreResult;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.DataPurifier;
import ml.shifu.shifu.core.ModelRunner;
import ml.shifu.shifu.core.posttrain.FeatureStatsWritable.BinStats;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
/**
 * {@link PostTrainMapper} is the mapper of the post-train job.
 *
 * <p>
 * For every accepted input record the mapper scores the record with the loaded models and adds the average score to
 * the matching bin of every final-selected variable. The per-variable bin sums/counts are kept in memory and emitted
 * once per mapper in {@link #cleanup(org.apache.hadoop.mapreduce.Mapper.Context)}.
 *
 * <p>
 * Besides the per-bin score accumulation, one score line per record is written to a side output through
 * {@link MultipleOutputs}. The {@link MultipleOutputs} instance is closed in cleanup so that this side output is
 * flushed.
 *
 * @author Zhang David (pengzhang@paypal.com)
 */
public class PostTrainMapper extends Mapper<LongWritable, Text, IntWritable, FeatureStatsWritable> {

    private static final Logger LOG = LoggerFactory.getLogger(PostTrainMapper.class);

    /**
     * Model Config read from HDFS
     */
    private ModelConfig modelConfig;

    /**
     * To filter records by customized expressions
     */
    private DataPurifier dataPurifier;

    /**
     * Output key cache to avoid new operation.
     */
    private IntWritable outputKey;

    /**
     * Tag column index
     */
    private int tagColumnNum = -1;

    /**
     * Column Config list read from HDFS
     */
    private List<ColumnConfig> columnConfigList;

    /**
     * Valid tags cached in a set for fast lookup.
     */
    private Set<String> tags;

    /**
     * Header names of the input data, one per column.
     */
    private String[] headers;

    /**
     * Scores one record with all loaded models.
     */
    private ModelRunner modelRunner;

    /**
     * Side output used to store per-record score lines.
     */
    private MultipleOutputs<NullWritable, Text> mos;

    /**
     * Output value cache to avoid creating a new object per record.
     */
    private Text outputValue;

    /**
     * Per-variable bin statistics, keyed by column number. Populated in setup, updated in map, emitted in cleanup.
     */
    private Map<Integer, List<BinStats>> variableStatsMap;

    /**
     * Load model config and column config from the locations given in the job configuration.
     *
     * @param context
     *            the mapper context carrying the job configuration
     * @throws RuntimeException
     *             wrapping any {@link IOException} raised while loading the config files
     */
    private void loadConfigFiles(final Context context) {
        try {
            SourceType sourceType = SourceType.valueOf(context.getConfiguration().get(
                    Constants.SHIFU_MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
            this.modelConfig = CommonUtils.loadModelConfig(
                    context.getConfiguration().get(Constants.SHIFU_MODEL_CONFIG), sourceType);
            this.columnConfigList = CommonUtils.loadColumnConfigList(
                    context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG), sourceType);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Load configs and models, create reusable output objects and initialize the in-memory bin statistics.
     */
    // The raw cast is required: the mapper context is typed with this mapper's output types while the side output
    // uses <NullWritable, Text>.
    @SuppressWarnings({ "rawtypes", "unchecked" })
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        loadConfigFiles(context);
        loadTagWeightNum();

        this.dataPurifier = new DataPurifier(this.modelConfig);
        this.outputKey = new IntWritable();
        this.outputValue = new Text();
        this.tags = new HashSet<String>(modelConfig.getFlattenTags());

        SourceType sourceType = this.modelConfig.getDataSet().getSource();
        List<BasicML> models = CommonUtils.loadBasicModels(modelConfig, null, sourceType);
        this.headers = CommonUtils.getFinalHeaders(modelConfig);
        this.modelRunner = new ModelRunner(modelConfig, columnConfigList, this.headers,
                modelConfig.getDataSetDelimiter(), models);

        this.mos = new MultipleOutputs<NullWritable, Text>((TaskInputOutputContext) context);

        this.initFeatureStats();
    }

    /**
     * Locate the target (tag) column index; the first target column found wins.
     *
     * @throws RuntimeException
     *             if no column is marked as target
     */
    private void loadTagWeightNum() {
        for(ColumnConfig config: this.columnConfigList) {
            if(config.isTarget()) {
                this.tagColumnNum = config.getColumnNum();
                break;
            }
        }

        if(this.tagColumnNum == -1) {
            throw new RuntimeException("No valid target column.");
        }
    }

    /**
     * Create zeroed bin statistics for every final-selected, non-meta, non-target column.
     *
     * <p>
     * Each list holds one entry per configured bin plus one trailing entry. The trailing entry collects records whose
     * bin lookup yields -1 (see the accumulation in the map path), so it must exist for categorical columns as well
     * as numerical ones; otherwise such records would be counted into the last real category.
     */
    private void initFeatureStats() {
        this.variableStatsMap = new HashMap<Integer, List<BinStats>>();
        for(ColumnConfig config: this.columnConfigList) {
            if(!config.isMeta() && !config.isTarget() && config.isFinalSelect()) {
                int binSize = 0;
                if(config.isNumerical()) {
                    binSize = config.getBinBoundary().size() + 1;
                }
                if(config.isCategorical()) {
                    // + 1: reserve the trailing slot for values that map to no category (binNum == -1)
                    binSize = config.getBinCategory().size() + 1;
                }
                List<BinStats> featureStatistics = new ArrayList<BinStats>(binSize);
                for(int i = 0; i < binSize; i++) {
                    featureStatistics.add(new BinStats(0, 0));
                }
                this.variableStatsMap.put(config.getColumnNum(), featureStatistics);
            }
        }
    }

    /**
     * Score one input record, write its score line to the side output and update the per-bin statistics.
     *
     * <p>
     * Blank lines, records rejected by the data purifier and records with an unknown tag are skipped. Records with
     * an unknown tag are counted in the "INVALID_TAG" counter.
     */
    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        String valueStr = value.toString();
        // StringUtils.isBlank is not used here to avoid import new jar
        if(valueStr == null || valueStr.trim().isEmpty()) {
            LOG.warn("Empty input.");
            return;
        }

        // skip the record when the purifier check returns false
        if(!this.dataPurifier.isFilterOut(valueStr)) {
            return;
        }

        String[] units = CommonUtils.split(valueStr, this.modelConfig.getDataSetDelimiter());
        // tagColumnNum should be in units array, if not IndexOutofBoundException
        String tag = CommonUtils.trimTag(units[this.tagColumnNum]);

        if(!this.tags.contains(tag)) {
            // sample the warning so that a badly tagged data set does not flood the task log
            if(System.currentTimeMillis() % 20 == 0) {
                LOG.warn("Data with invalid tag is ignored in post train, invalid tag: {}.", tag);
            }
            context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L);
            return;
        }

        Map<String, String> rawDataMap = buildRawDataMap(units);
        CaseScoreResult csr = this.modelRunner.compute(rawDataMap);

        // store score value
        this.outputValue.set(buildScoreLine(csr, rawDataMap));
        this.mos.write(Constants.POST_TRAIN_OUTPUT_SCORE, NullWritable.get(), this.outputValue);

        updateBinStats(units, csr);
    }

    /**
     * Build the score line: avg, max, min, each model score, then each meta column value, joined by the default
     * delimiter.
     */
    private String buildScoreLine(CaseScoreResult csr, Map<String, String> rawDataMap) {
        StringBuilder sb = new StringBuilder(500);
        sb.append(csr.getAvgScore()).append(Constants.DEFAULT_DELIMITER).append(csr.getMaxScore())
                .append(Constants.DEFAULT_DELIMITER).append(csr.getMinScore()).append(Constants.DEFAULT_DELIMITER);
        for(Double score: csr.getScores()) {
            sb.append(score).append(Constants.DEFAULT_DELIMITER);
        }
        List<String> metaList = modelConfig.getMetaColumnNames();
        for(String meta: metaList) {
            sb.append(rawDataMap.get(meta)).append(Constants.DEFAULT_DELIMITER);
        }
        // Every append above ends with the delimiter, so drop the whole trailing delimiter (not just one char).
        sb.setLength(sb.length() - Constants.DEFAULT_DELIMITER.length());
        return sb.toString();
    }

    /**
     * Add the record's average score to the matching bin of every final-selected, non-meta, non-target column.
     */
    private void updateBinStats(String[] units, CaseScoreResult csr) {
        // NOTE(review): assumes headers, columnConfigList and units are index-aligned - confirm against callers
        for(int i = 0; i < headers.length; i++) {
            ColumnConfig config = this.columnConfigList.get(i);
            if(!config.isMeta() && !config.isTarget() && config.isFinalSelect()) {
                int binNum = CommonUtils.getBinNum(config, units[i]);
                List<BinStats> featureStatistics = this.variableStatsMap.get(config.getColumnNum());
                // if -1, means invalid value like null or empty, last one is for empty stats.
                int binIndex = (binNum == -1) ? featureStatistics.size() - 1 : binNum;
                // bs should not be null as already initialized in setup
                BinStats bs = featureStatistics.get(binIndex);
                bs.setBinSum(csr.getAvgScore() + bs.getBinSum());
                bs.setBinCnt(1L + bs.getBinCnt());
            }
        }
    }

    /**
     * Map each header name to the value at the same index of the record; null values become empty strings.
     *
     * @param units
     *            the split columns of one record, expected to have at least {@code headers.length} entries
     * @return header name to raw value map
     */
    private Map<String, String> buildRawDataMap(String[] units) {
        Map<String, String> rawDataMap = new HashMap<String, String>(headers.length, 1f);
        for(int i = 0; i < headers.length; i++) {
            rawDataMap.put(headers[i], units[i] == null ? "" : units[i]);
        }
        return rawDataMap;
    }

    /**
     * Emit the accumulated bin statistics, one output record per variable, then close the side output.
     */
    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
        try {
            for(Entry<Integer, List<BinStats>> entry: this.variableStatsMap.entrySet()) {
                this.outputKey.set(entry.getKey());
                context.write(this.outputKey, new FeatureStatsWritable(entry.getValue()));
            }
        } finally {
            // MultipleOutputs must be closed, otherwise the score side output may not be flushed.
            this.mos.close();
        }
    }
}