/*
* Copyright [2012-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.binning;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.DataPurifier;
import ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems;
import ml.shifu.shifu.core.autotype.CountAndFrequentItemsWritable;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Splitter;
/**
* {@link UpdateBinningInfoMapper} is a mapper to update local data statistics given bin boundary list.
*
* <p>
* The bin boundary list is read from the distributed cache. After loading it, the mapper iterates over each record
* and updates the count and weighted value of the bin that record falls into.
*
* <p>
* This map-reduce job solves the issue of grouping all data together per column in the pig version. It is a job
* with multiple mappers and one reducer, so it scales well.
*
* <p>
* We assume that all column info can be saved in mapper memory.
*
* <p>
* 'median' can not be computed through such distributed solution.
*/
public class UpdateBinningInfoMapper extends Mapper<LongWritable, Text, IntWritable, BinningInfoWritable> {
private static final Logger LOG = LoggerFactory.getLogger(UpdateBinningInfoMapper.class);
/**
 * Delimiter used to split each input record into column values. Read from the model config in setup.
 */
private String dataSetDelimiter;
/**
 * Model Config read from HDFS
 */
private ModelConfig modelConfig;
/**
 * To filter records by customized expressions
 */
private DataPurifier dataPurifier;
/**
 * Weight column index; stays -1 when no weight column is configured, in which case every record weighs 1.0.
 */
private int weightedColumnNum = -1;
/**
 * Column Config list read from HDFS
 */
private List<ColumnConfig> columnConfigList;
/**
 * Tag (target) column index; -1 until the target column has been located.
 */
private int tagColumnNum = -1;
/**
 * A map to store all statistics for all columns, keyed by column number.
 */
private Map<Integer, BinningInfoWritable> columnBinningInfo;
/**
 * Bin boundary list splitter. Shared and never reassigned, hence a constant.
 */
private static final Splitter BIN_BOUNDARY_SPLITTER = Splitter.on(Constants.BIN_BOUNDRY_DELIMITER).trimResults();
/**
 * Output key cache to avoid new operation.
 */
private IntWritable outputKey;
/**
 * Per categorical column: lookup from category value to its bin index.
 *
 * TODO At risk to be a big memory cost OOM.
 */
private Map<Integer, Map<String, Integer>> categoricalBinMap;
/**
 * Using approximate method to estimate real frequent items and store into this map, keyed by column index.
 */
private Map<Integer, CountAndFrequentItems> variableCountMap;
// Tags cached in sets for fast lookup while mapping.
/** Positive tags taken from the model config. */
private Set<String> posTags;
/** Negative tags taken from the model config. */
private Set<String> negTags;
/** Flattened tags taken from the model config; used to validate tags when the model is not a regression. */
private Set<String> tags;
/** Values configured on the data set as missing or invalid. */
private Set<String> missingOrInvalidValues;
/** Number of records seen by this mapper whose weight was unparsable or negative. */
private int weightExceptions = 0;
/**
 * Read from the "shifu.weight.exception" job setting; when true the mapper fails once more than 5000 weight
 * exceptions have been counted.
 */
private boolean isThrowforWeightException;
/**
 * Reads the model config and the column config list from the locations named in the job configuration.
 *
 * <p>
 * The source type defaults to HDFS when the job configuration does not specify one. Any {@link IOException} raised
 * while loading is rethrown as a {@link RuntimeException}.
 */
private void loadConfigFiles(final Context context) {
    final String sourceTypeName = context.getConfiguration().get(Constants.SHIFU_MODELSET_SOURCE_TYPE,
            SourceType.HDFS.toString());
    final SourceType sourceType = SourceType.valueOf(sourceTypeName);
    final String modelConfigPath = context.getConfiguration().get(Constants.SHIFU_MODEL_CONFIG);
    final String columnConfigPath = context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG);
    try {
        this.modelConfig = CommonUtils.loadModelConfig(modelConfigPath, sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(columnConfigPath, sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Prepares all per-mapper state: configs, weight and tag column indexes, empty per-column statistics and the tag
 * lookup sets used while mapping.
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
    loadConfigFiles(context);

    this.dataSetDelimiter = this.modelConfig.getDataSetDelimiter();
    this.dataPurifier = new DataPurifier(this.modelConfig);

    loadWeightColumnNum();
    loadTagWeightNum();

    // Both per-column maps get an initial capacity derived from the number of configured columns.
    final int initialCapacity = this.columnConfigList.size() * 4 / 3;
    this.columnBinningInfo = new HashMap<Integer, BinningInfoWritable>(initialCapacity);
    this.categoricalBinMap = new HashMap<Integer, Map<String, Integer>>(initialCapacity);
    // Must run after the two maps above exist, since it fills both of them.
    loadColumnBinningInfo();

    this.outputKey = new IntWritable();
    this.variableCountMap = new HashMap<Integer, CountAndFrequentItems>();

    this.posTags = new HashSet<String>(this.modelConfig.getPosTags());
    this.negTags = new HashSet<String>(this.modelConfig.getNegTags());
    this.tags = new HashSet<String>(this.modelConfig.getFlattenTags());
    this.missingOrInvalidValues = new HashSet<String>(this.modelConfig.getDataSet().getMissingOrInvalidValues());

    final String throwOnWeightException = context.getConfiguration().get("shifu.weight.exception", "false");
    this.isThrowforWeightException = "true".equalsIgnoreCase(throwOnWeightException);

    LOG.debug("Column binning info: {}", this.columnBinningInfo);
}
/**
 * Reads the binning info file and registers an empty statistics holder for every column listed in it.
 *
 * <p>
 * Each usable line carries a column number followed by that column's bin list. For numerical columns the list is
 * parsed into double boundaries; for all other columns it is kept as a category list and additionally indexed in
 * {@link #categoricalBinMap}. Count and weight arrays are allocated with one slot more than the bin list; the
 * mapper uses that extra last slot for missing or invalid values. Reading stops at end of file or at the first
 * empty line, and lines with fewer than two fields are ignored.
 *
 * @throws FileNotFoundException if the binning info file cannot be opened
 * @throws IOException if reading the file fails
 */
private void loadColumnBinningInfo() throws FileNotFoundException, IOException {
    BufferedReader reader = null;
    try {
        reader = new BufferedReader(new InputStreamReader(new FileInputStream(Constants.BINNING_INFO_FILE_NAME),
                Charset.forName("UTF-8")));
        for(String line = reader.readLine(); line != null && line.length() != 0; line = reader.readLine()) {
            LOG.debug("line is {}", line);
            // only the first two fields (column number, bin list) are consumed
            String[] cols = CommonUtils.split(line.trim(), Constants.DEFAULT_DELIMITER);
            if(cols == null || cols.length < 2) {
                continue;
            }

            Integer columnNum = Integer.parseInt(cols[0]);
            BinningInfoWritable binningInfo = new BinningInfoWritable();
            binningInfo.setColumnNum(columnNum);
            ColumnConfig columnConfig = this.columnConfigList.get(columnNum);

            int binSize = 0;
            if(columnConfig.isNumerical()) {
                binningInfo.setNumeric(true);
                List<Double> boundaries = new ArrayList<Double>();
                for(String boundary: BIN_BOUNDARY_SPLITTER.split(cols[1])) {
                    boundaries.add(Double.valueOf(boundary));
                }
                binningInfo.setBinBoundaries(boundaries);
                binSize = boundaries.size();
            } else {
                binningInfo.setNumeric(false);
                List<String> categories = new ArrayList<String>();
                Map<String, Integer> categoryIndex = this.categoricalBinMap.get(columnNum);
                if(categoryIndex == null) {
                    categoryIndex = new HashMap<String, Integer>();
                    this.categoricalBinMap.put(columnNum, categoryIndex);
                }
                // a blank bin list yields a column with no categories at all
                if(!StringUtils.isBlank(cols[1])) {
                    int index = 0;
                    for(String category: BIN_BOUNDARY_SPLITTER.split(cols[1])) {
                        categories.add(category);
                        categoryIndex.put(category, index++);
                    }
                }
                binningInfo.setBinCategories(categories);
                binSize = categories.size();
            }

            // one extra slot at the end for missing or invalid values
            final int slots = binSize + 1;
            binningInfo.setBinCountPos(new long[slots]);
            binningInfo.setBinCountNeg(new long[slots]);
            binningInfo.setBinWeightPos(new double[slots]);
            binningInfo.setBinWeightNeg(new double[slots]);

            LOG.info("column num {} and info {}", columnNum, binningInfo);
            this.columnBinningInfo.put(columnNum, binningInfo);
        }
    } finally {
        if(reader != null) {
            reader.close();
        }
    }
}
/**
 * Locates the target (tag) column and remembers its column number.
 *
 * <p>
 * The first column config flagged as target wins. Fails with a {@link RuntimeException} when the tag column index
 * is still -1 afterwards.
 */
private void loadTagWeightNum() {
    for(ColumnConfig config: this.columnConfigList) {
        if(!config.isTarget()) {
            continue;
        }
        this.tagColumnNum = config.getColumnNum();
        break;
    }
    if(this.tagColumnNum == -1) {
        throw new RuntimeException("No valid target column.");
    }
}
/**
 * Resolves the position of the configured weight column within the column config list.
 *
 * <p>
 * Leaves the weight column index untouched when no weight column name is configured or when no column config
 * carries that name.
 */
private void loadWeightColumnNum() {
    final String weightColumnName = this.modelConfig.getDataSet().getWeightColumnName();
    if(weightColumnName == null || weightColumnName.isEmpty()) {
        return;
    }
    int position = 0;
    for(ColumnConfig config: this.columnConfigList) {
        if(config.getColumnName().equals(weightColumnName)) {
            // the list position, not the config's own column number, is what gets stored
            this.weightedColumnNum = position;
            break;
        }
        position++;
    }
}
/**
 * Processes one input record: filters it, validates its tag, resolves its weight and then updates the value
 * counters and bin statistics of every column in the record.
 *
 * <p>
 * Records that are empty, rejected by the data purifier or carry an unknown tag only touch the job counters and
 * are otherwise skipped. Nothing is written to the context here; the accumulated statistics are emitted in
 * {@code cleanup}.
 *
 * @param key input key, not used
 * @param value one raw input record
 * @param context mapper context, used for job counters
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    if(valueStr == null || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    context.getCounter(Constants.SHIFU_GROUP_COUNTER, "TOTAL_VALID_COUNT").increment(1L);

    // NOTE(review): the record is dropped when isFilterOut(...) returns false, which reads inverted relative to
    // the method name - confirm against DataPurifier's contract.
    if(!this.dataPurifier.isFilterOut(valueStr)) {
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "FILTER_OUT_COUNT").increment(1L);
        return;
    }

    String[] units = CommonUtils.split(valueStr, this.dataSetDelimiter);
    // tagColumnNum should be in units array, if not IndexOutofBoundException
    String tag = CommonUtils.trimTag(units[this.tagColumnNum]);

    final boolean isRegression = modelConfig.isRegression();
    final boolean isKnownTag;
    if(tag == null) {
        isKnownTag = false;
    } else if(isRegression) {
        isKnownTag = posTags.contains(tag) || negTags.contains(tag);
    } else {
        isKnownTag = tags.contains(tag);
    }
    if(!isKnownTag) {
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L);
        return;
    }

    final double weight = resolveWeight(units, context);

    // valid data process
    for(int i = 0; i < units.length; i++) {
        ColumnConfig columnConfig = this.columnConfigList.get(i);

        CountAndFrequentItems countAndFrequentItems = this.variableCountMap.get(i);
        if(countAndFrequentItems == null) {
            countAndFrequentItems = new CountAndFrequentItems();
            this.variableCountMap.put(i, countAndFrequentItems);
        }
        countAndFrequentItems.offer(this.missingOrInvalidValues, units[i]);

        // Meta and target columns are deliberately not skipped here.
        BinningInfoWritable binningInfoWritable = this.columnBinningInfo.get(i);
        if(binningInfoWritable == null) {
            continue; // no bin boundaries were loaded for this column
        }
        binningInfoWritable.setTotalCount(binningInfoWritable.getTotalCount() + 1L);

        if(columnConfig.isCategorical()) {
            updateCategoricalStats(i, binningInfoWritable, units[i], tag, weight, isRegression);
        } else if(columnConfig.isNumerical()) {
            updateNumericalStats(binningInfoWritable, units[i], tag, weight, isRegression);
        }
    }
}

/**
 * Reads the record weight from the weight column, falling back to 1.0 when no weight column is configured or the
 * value is not a parsable double. Unparsable and negative weights are counted as weight exceptions.
 *
 * <p>
 * NOTE(review): a negative weight is counted as an exception but is still returned and accumulated into the bin
 * weights - confirm whether it should be replaced by a default instead.
 */
private double resolveWeight(String[] units, Context context) {
    if(this.weightedColumnNum == -1) {
        return 1.0d;
    }
    final double weight;
    try {
        weight = Double.parseDouble(units[this.weightedColumnNum]);
    } catch (NumberFormatException e) {
        recordWeightException(context);
        return 1.0d;
    }
    if(weight < 0) {
        recordWeightException(context);
    }
    return weight;
}

/**
 * Counts one weight exception and, if configured to do so, fails the mapper once more than 5000 were seen.
 */
private void recordWeightException(Context context) {
    this.weightExceptions += 1;
    context.getCounter(Constants.SHIFU_GROUP_COUNTER, "WEIGHT_EXCEPTION").increment(1L);
    if(this.weightExceptions > 5000 && this.isThrowforWeightException) {
        throw new IllegalStateException(
                "Please check weight column in eval, exceptional weight count is over 5000");
    }
}

/**
 * Adds one record to the positive or negative count and weight of the given bin, depending on which tag set the
 * tag belongs to. A tag in neither set leaves the bin untouched.
 */
private void updateTagBinStats(BinningInfoWritable binningInfo, int binNum, String tag, double weight) {
    if(posTags.contains(tag)) {
        binningInfo.getBinCountPos()[binNum] += 1L;
        binningInfo.getBinWeightPos()[binNum] += weight;
    } else if(negTags.contains(tag)) {
        binningInfo.getBinCountNeg()[binNum] += 1L;
        binningInfo.getBinWeightNeg()[binNum] += weight;
    }
}

/**
 * Updates the bin statistics of a categorical column for one value. Missing values and values without a known
 * category go to the extra last bin and are counted as missing.
 */
private void updateCategoricalStats(int columnNum, BinningInfoWritable binningInfo, String rawValue, String tag,
        double weight, boolean isRegression) {
    final int missingBinIndex = binningInfo.getBinCategories().size();

    // a negative bin number marks "missing or unknown category"
    int binNum = -1;
    if(rawValue != null && !missingOrInvalidValues.contains(rawValue.toLowerCase())) {
        binNum = quickLocateCategoricalBin(this.categoricalBinMap.get(columnNum), StringUtils.trim(rawValue));
    }
    if(binNum < 0) {
        binningInfo.setMissingCount(binningInfo.getMissingCount() + 1L);
        binNum = missingBinIndex;
    }

    if(isRegression) {
        updateTagBinStats(binningInfo, binNum, tag, weight);
    } else {
        // for multiple classification, set bin count to BinCountPos and leave BinCountNeg empty
        binningInfo.getBinCountPos()[binNum] += 1L;
        binningInfo.getBinWeightPos()[binNum] += weight;
    }
}

/**
 * Updates the bin statistics and running moments (sum, squared, triple and quartic sums, min, max) of a numerical
 * column for one value. Empty, unparsable and over-threshold values go to the extra last bin, are counted as
 * missing and do not contribute to the moments.
 *
 * <p>
 * NOTE(review): unlike the categorical case, bin counts are only updated for regression models here; for
 * multi-class models numerical bin counts stay untouched - confirm this is intended.
 */
private void updateNumericalStats(BinningInfoWritable binningInfo, String rawValue, String tag, double weight,
        boolean isRegression) {
    final int missingBinIndex = binningInfo.getBinBoundaries().size();

    boolean isMissingOrInvalid = false;
    double douVal = 0.0;
    if(rawValue == null || rawValue.length() == 0) {
        isMissingOrInvalid = true;
    } else {
        try {
            douVal = Double.parseDouble(rawValue.trim());
        } catch (NumberFormatException e) {
            isMissingOrInvalid = true;
        }
    }
    // add logic the same as CalculateNewStatsUDF
    if(Double.compare(douVal, modelConfig.getNumericalValueThreshold()) > 0) {
        isMissingOrInvalid = true;
    }

    if(isMissingOrInvalid) {
        binningInfo.setMissingCount(binningInfo.getMissingCount() + 1L);
        if(isRegression) {
            updateTagBinStats(binningInfo, missingBinIndex, tag, weight);
        }
        // For invalid or missing values, no need update sum, squaredSum, max, min ...
        return;
    }

    // The value is already parsed, so look the bin up directly instead of re-parsing the raw string.
    int binNum = CommonUtils.getBinIndex(binningInfo.getBinBoundaries(), douVal);
    if(binNum == -1) {
        throw new RuntimeException("binNum should not be -1 to this step.");
    }
    if(isRegression) {
        updateTagBinStats(binningInfo, binNum, tag, weight);
    }

    binningInfo.setSum(binningInfo.getSum() + douVal);
    double squaredVal = douVal * douVal;
    binningInfo.setSquaredSum(binningInfo.getSquaredSum() + squaredVal);
    binningInfo.setTripleSum(binningInfo.getTripleSum() + squaredVal * douVal);
    binningInfo.setQuarticSum(binningInfo.getQuarticSum() + squaredVal * squaredVal);
    if(Double.compare(binningInfo.getMax(), douVal) < 0) {
        binningInfo.setMax(douVal);
    }
    if(Double.compare(binningInfo.getMin(), douVal) > 0) {
        binningInfo.setMin(douVal);
    }
}
/**
 * Finds the bin a raw numerical column value belongs to.
 *
 * @param binBoundaryList
 *            bin boundaries of the column, handed to {@code CommonUtils.getBinIndex} unchanged
 * @param columnVal
 *            raw column value as read from the record
 * @return the index reported by {@code CommonUtils.getBinIndex}, or -1 when {@code columnVal} is blank or is not
 *         a parsable double
 */
public static int getBinNum(List<Double> binBoundaryList, String columnVal) {
    if(StringUtils.isBlank(columnVal)) {
        return -1;
    }
    final double dval;
    try {
        dval = Double.parseDouble(columnVal);
    } catch (NumberFormatException e) {
        // columnVal is non-null here, so a malformed number is the only way parsing can fail
        return -1;
    }
    return CommonUtils.getBinIndex(binBoundaryList, dval);
}
/**
 * Looks up the bin index of a category value in the given category-to-index map.
 *
 * @return the mapped bin index, or -1 when the map holds no entry for {@code val}
 */
private int quickLocateCategoricalBin(Map<String, Integer> map, String val) {
    final Integer mappedIndex = map.get(val);
    if(mappedIndex == null) {
        return -1;
    }
    return mappedIndex;
}
/**
 * Emits the statistics gathered by this mapper, one output record per column that has binning info.
 *
 * <p>
 * Where value counts were collected for a column they are attached to its binning info before it is written;
 * otherwise the column is written without them and the gap is logged.
 */
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
    LOG.debug("Column binning info: {}", this.columnBinningInfo);
    LOG.debug("Column count info: {}", this.variableCountMap);

    for(Map.Entry<Integer, BinningInfoWritable> entry: this.columnBinningInfo.entrySet()) {
        final Integer columnNum = entry.getKey();
        final BinningInfoWritable binningInfo = entry.getValue();

        final CountAndFrequentItems counts = this.variableCountMap.get(columnNum);
        if(counts == null) {
            LOG.info("cci is null for column {}", columnNum);
        } else {
            binningInfo.setCfiw(new CountAndFrequentItemsWritable(counts.getCount(), counts.getInvalidCount(),
                    counts.getValidNumCount(), counts.getHyper().getBytes(), counts.getFrequentItems()));
        }

        // the key object is reused for every column instead of allocating a new one
        this.outputKey.set(columnNum);
        context.write(this.outputKey, binningInfo);
    }
}
}