/*
* Copyright [2013-2016] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.correlation;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ColumnConfig.ColumnFlag;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.DataPurifier;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link CorrelationMapper} is used to compute {@link CorrelationWritable} per column per mapper.
 *
 * <p>
 * Such {@link CorrelationWritable} is sent to reducer (only one) to merge and compute real pearson value.
 *
 * <p>
 * This class itself never writes to the {@link Context}: every record is folded into the static
 * {@code CorrelationMultithreadedMapper.finalCorrelationMap}, and each entry of that map is locked while it is
 * being updated.
 *
 * @author Zhang David (pengzhang@paypal.com)
 */
public class CorrelationMapper extends Mapper<LongWritable, Text, IntWritable, CorrelationWritable> {

    private static final Logger LOG = LoggerFactory.getLogger(CorrelationMapper.class);

    /**
     * Marker stored for a missing or unparsable value; pair statistics skip any pair containing this marker.
     */
    private static final double MISSING_VALUE = Double.MIN_VALUE;

    /**
     * Delimiter used to split an input record into raw column values.
     */
    private String dataSetDelimiter;

    /**
     * Model Config read from HDFS
     */
    private ModelConfig modelConfig;

    /**
     * To filter records by customized expressions
     */
    private DataPurifier dataPurifier;

    /**
     * Number of records accumulated by this mapper instance (after filtering and sampling).
     */
    private long count;

    /**
     * Column Config list read from HDFS
     */
    private List<ColumnConfig> columnConfigList;

    /**
     * For categorical feature, a map (column number to (category value to bin index)) is used to save query time
     * in execution
     */
    private Map<Integer, Map<String, Integer>> categoricalIndexMap = new HashMap<Integer, Map<String, Integer>>();

    /**
     * If true, all pairs (i, j) are computed; if false, only pairs (i, j) with i &lt;= j are computed.
     */
    private boolean isComputeAll = false;

    // cache tags in set for search
    protected Set<String> posTagSet;
    protected Set<String> negTagSet;
    protected Set<String> tagSet;

    /**
     * Tag groups from {@code ModelConfig.getSetTags()}; the position of a group is used as the numeric target
     * value for non-regression models.
     */
    private List<Set<String>> tags;

    /**
     * Loads the model config and the column config list from the locations given in the job configuration.
     *
     * @param context
     *            the mapper context providing the job configuration
     * @throws RuntimeException
     *             wrapping the {@link IOException} if either config cannot be loaded
     */
    private void loadConfigFiles(final Context context) {
        try {
            SourceType sourceType = SourceType.valueOf(context.getConfiguration().get(
                    Constants.SHIFU_MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
            this.modelConfig = CommonUtils.loadModelConfig(
                    context.getConfiguration().get(Constants.SHIFU_MODEL_CONFIG), sourceType);
            this.columnConfigList = CommonUtils.loadColumnConfigList(
                    context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG), sourceType);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads configs and builds the lookup structures used per record: the record filter, the categorical
     * value-to-bin-index maps and the tag sets.
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        loadConfigFiles(context);

        this.dataSetDelimiter = this.modelConfig.getDataSetDelimiter();
        this.dataPurifier = new DataPurifier(this.modelConfig);
        this.isComputeAll = Boolean.parseBoolean(context.getConfiguration().get(
                Constants.SHIFU_CORRELATION_COMPUTE_ALL, "false"));

        for(ColumnConfig config: this.columnConfigList) {
            if(!config.isCategorical()) {
                continue;
            }
            // A categorical column always gets a map (possibly empty) so later lookups never hit a null map.
            Map<String, Integer> map = new HashMap<String, Integer>();
            if(config.getBinCategory() != null) {
                for(int i = 0; i < config.getBinCategory().size(); i++) {
                    // one bin may group several category values; every one of them maps to the same bin index
                    List<String> cvals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for(String cval: cvals) {
                        map.put(cval, i);
                    }
                }
            }
            this.categoricalIndexMap.put(config.getColumnNum(), map);
        }

        // modelConfig has already been dereferenced above, so only the tag lists need a null check here
        if(this.modelConfig.getPosTags() != null) {
            this.posTagSet = new HashSet<String>(this.modelConfig.getPosTags());
        }
        if(this.modelConfig.getNegTags() != null) {
            this.negTagSet = new HashSet<String>(this.modelConfig.getNegTags());
        }
        if(this.modelConfig.getFlattenTags() != null) {
            this.tagSet = new HashSet<String>(this.modelConfig.getFlattenTags());
        }
        this.tags = this.modelConfig.getSetTags();
    }

    /**
     * Filters and samples one input record, converts it to numeric values and folds it into the shared
     * per-column statistics. Nothing is written to {@code context} here apart from counters.
     */
    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        String valueStr = value.toString();
        if(valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) {
            LOG.warn("Empty input.");
            return;
        }

        // NOTE(review): a record is processed only when isFilterOut(..) returns true; presumably 'true' means the
        // record passed the filter expression - confirm against DataPurifier.
        if(!this.dataPurifier.isFilterOut(valueStr)) {
            return;
        }
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CNT_AFTER_FILTER").increment(1L);

        // make sampling work in correlation
        if(Math.random() >= this.modelConfig.getStats().getSampleRate()) {
            return;
        }
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CORRELATION_CNT").increment(1L);

        double[] dValues = getDoubleArrayByRawArray(CommonUtils.split(valueStr, this.dataSetDelimiter));

        count += 1L;
        if(count % 2000L == 0) {
            LOG.info("Current records: {} in thread {}.", count, Thread.currentThread().getName());
        }

        for(int i = 0; i < this.columnConfigList.size(); i++) {
            ColumnConfig columnConfig = this.columnConfigList.get(i);
            if(columnConfig.getColumnFlag() == ColumnFlag.Meta) {
                continue;
            }
            // NOTE(review): assumes finalCorrelationMap already holds an entry for every non-meta column index;
            // a missing entry would fail on the synchronized statement below - confirm against the owner class.
            CorrelationWritable cw = CorrelationMultithreadedMapper.finalCorrelationMap.get(i);
            synchronized(cw) {
                updateColumnStats(cw, i, dValues);
            }
        }
    }

    /**
     * Adds one record to the statistics of column {@code i}. Caller must hold the lock on {@code cw}.
     *
     * @param cw
     *            accumulator of column {@code i}
     * @param i
     *            index of the column in {@link #columnConfigList}
     * @param dValues
     *            numeric values of the current record, indexed like {@link #columnConfigList}
     */
    private void updateColumnStats(CorrelationWritable cw, int i, double[] dValues) {
        int columnCount = this.columnConfigList.size();
        double x = dValues[i];

        // count, sum and sumSquare are updated for every record, including one whose value is the missing marker
        cw.setColumnIndex(i);
        cw.setCount(cw.getCount() + 1d);
        cw.setSum(cw.getSum() + x);
        double squaredSum = x * x;
        cw.setSumSquare(cw.getSumSquare() + squaredSum);

        // pair-wise arrays are created lazily on first use
        double[] xySum = cw.getXySum();
        if(xySum == null) {
            xySum = new double[columnCount];
            cw.setXySum(xySum);
        }
        double[] xxSum = cw.getXxSum();
        if(xxSum == null) {
            xxSum = new double[columnCount];
            cw.setXxSum(xxSum);
        }
        double[] yySum = cw.getYySum();
        if(yySum == null) {
            yySum = new double[columnCount];
            cw.setYySum(yySum);
        }
        double[] adjustCount = cw.getAdjustCount();
        if(adjustCount == null) {
            adjustCount = new double[columnCount];
            cw.setAdjustCount(adjustCount);
        }
        double[] adjustSumX = cw.getAdjustSumX();
        if(adjustSumX == null) {
            adjustSumX = new double[columnCount];
            cw.setAdjustSumX(adjustSumX);
        }
        double[] adjustSumY = cw.getAdjustSumY();
        if(adjustSumY == null) {
            adjustSumY = new double[columnCount];
            cw.setAdjustSumY(adjustSumY);
        }

        for(int j = 0; j < columnCount; j++) {
            ColumnConfig otherColumnConfig = this.columnConfigList.get(j);
            if(otherColumnConfig.getColumnFlag() == ColumnFlag.Meta) {
                continue;
            }
            // unless all pairs are requested, only pairs with i <= j are accumulated
            if(i > j && !this.isComputeAll) {
                continue;
            }
            double y = dValues[j];
            // only do stats on both valid values
            if(x != MISSING_VALUE && y != MISSING_VALUE) {
                xySum[j] += x * y;
                xxSum[j] += squaredSum;
                yySum[j] += y * y;
                adjustCount[j] += 1d;
                adjustSumX[j] += x;
                adjustSumY[j] += y;
            }
        }
    }

    /**
     * Converts the raw string fields of one record into numeric values, one per configured column.
     *
     * <ul>
     * <li>meta column: 0</li>
     * <li>target column: see {@link #toTargetValue(String)}</li>
     * <li>numerical column: parsed double, or the missing marker if empty or unparsable</li>
     * <li>categorical column: positive rate of the bin the value falls in, see
     * {@link #toCategoricalValue(ColumnConfig, String)}</li>
     * </ul>
     *
     * @param units
     *            raw fields of the record, indexed like {@link #columnConfigList}
     * @return numeric values, indexed like {@link #columnConfigList}
     */
    private double[] getDoubleArrayByRawArray(String[] units) {
        double[] dValues = new double[this.columnConfigList.size()];
        for(int i = 0; i < this.columnConfigList.size(); i++) {
            ColumnConfig columnConfig = this.columnConfigList.get(i);
            if(columnConfig.getColumnFlag() == ColumnFlag.Meta) {
                // only meta columns not in correlation
                dValues[i] = 0d;
            } else if(columnConfig.getColumnFlag() == ColumnFlag.Target) {
                dValues[i] = toTargetValue(units[i]);
            } else {
                if(columnConfig.isNumerical()) {
                    dValues[i] = toNumericalValue(units[i]);
                }
                if(columnConfig.isCategorical()) {
                    if(columnConfig.getBinCategory() == null) {
                        // warn only now and then to avoid flooding the log for every record
                        if(System.currentTimeMillis() % 100L == 0) {
                            LOG.warn(
                                    "Column {} with null binCategory but is not meta or target column, set to 0d for correlation.",
                                    columnConfig.getColumnName());
                        }
                        dValues[i] = 0d;
                    } else {
                        dValues[i] = toCategoricalValue(columnConfig, units[i]);
                    }
                }
            }
        }
        return dValues;
    }

    /**
     * Maps a raw target value to a number.
     *
     * @param rawTag
     *            raw value of the target column
     * @return the missing marker if the tag is not a configured tag; for regression models 1 for a positive tag
     *         and 0 otherwise; for other models the index of the tag group containing the tag, or -1 if no group
     *         contains it
     */
    private double toTargetValue(String rawTag) {
        if(!this.tagSet.contains(rawTag)) {
            // Invalid target
            return MISSING_VALUE;
        }

        if(this.modelConfig.isRegression()) {
            double targetValue = 0d;
            if(this.posTagSet.contains(rawTag)) {
                targetValue = 1d;
            }
            if(this.negTagSet.contains(rawTag)) {
                targetValue = 0d;
            }
            return targetValue;
        }

        // The lookup must use the target column's own value; reading the first field of the record here would
        // derive the class index from an unrelated column whenever the target is not column 0.
        for(int j = 0; j < this.tags.size(); j++) {
            if(this.tags.get(j).contains(rawTag)) {
                return j;
            }
        }
        return -1;
    }

    /**
     * Parses a raw numerical value.
     *
     * @param raw
     *            raw value of a numerical column
     * @return the parsed value, or the missing marker if the value is null, empty or cannot be parsed
     */
    private double toNumericalValue(String raw) {
        // if missing it is set to MIN_VALUE, then try to skip rows with invalid value
        if(raw == null || raw.length() == 0) {
            // some null values, set it to min value to avoid parsing String to improve performance
            return MISSING_VALUE;
        }
        return NumberFormatUtils.getDouble(raw, MISSING_VALUE);
    }

    /**
     * Maps a raw categorical value to the positive rate of its bin. The last entry of the column's bin positive
     * rate list is used when the value is null or unknown, or when the rate of its bin is null.
     *
     * @param columnConfig
     *            config of a categorical column whose bin categories are not null
     * @param raw
     *            raw value of that column
     * @return the positive rate used as numeric stand-in for the category
     */
    private double toCategoricalValue(ColumnConfig columnConfig, String raw) {
        Integer index = null;
        if(raw != null) {
            index = this.categoricalIndexMap.get(columnConfig.getColumnNum()).get(raw);
        }
        if(index == null || index.intValue() == -1) {
            return lastBinPosRate(columnConfig);
        }
        Double binPosRate = columnConfig.getBinPosRate().get(index);
        if(binPosRate == null) {
            return lastBinPosRate(columnConfig);
        }
        return binPosRate;
    }

    /**
     * Returns the last entry of the column's bin positive rate list.
     */
    private double lastBinPosRate(ColumnConfig columnConfig) {
        return columnConfig.getBinPosRate().get(columnConfig.getBinPosRate().size() - 1);
    }

    /**
     * Logs the number of records accumulated by this mapper instance. Nothing is written to {@code context}
     * here.
     */
    @Override
    protected void cleanup(Context context) throws IOException, InterruptedException {
        LOG.info("Final records in such thread of mapper: {}.", count);
    }
}