/*
* Copyright [2012-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.binning;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.*;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.ColumnStatsCalculator;
import ml.shifu.shifu.core.ColumnStatsCalculator.ColumnMetrics;
import ml.shifu.shifu.core.autotype.CountAndFrequentItemsWritable;
import ml.shifu.shifu.udf.CalculateStatsUDF;
import ml.shifu.shifu.util.Base64Utils;
import ml.shifu.shifu.util.CommonUtils;
import ml.shifu.shifu.util.Constants;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.clearspring.analytics.stream.cardinality.CardinalityMergeException;
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus;
/**
* Collect all statistics together in reducer.
*
* <p>
* The output uses the same format as the previous output so that it stays consistent with the output processing functions.
*
* <p>
* Only one reducer is used so that all info can be collected together. A single reducer is not a bottleneck because
* sometimes there are only thousands of variables.
*/
public class UpdateBinningInfoReducer extends Reducer<IntWritable, BinningInfoWritable, NullWritable, Text> {

    private static final Logger LOG = LoggerFactory.getLogger(UpdateBinningInfoReducer.class);

    /**
     * Categorical columns with more categories than this are logged and skipped.
     */
    private static final int MAX_CATEGORICAL_BINC_COUNT = 5000;

    /**
     * Small constant added to the variance numerator before the square root is taken.
     */
    private static final double EPS = 1e-6;

    /**
     * Column Config list read from HDFS
     */
    private List<ColumnConfig> columnConfigList;

    /**
     * Reused output value, to prevent creating a new object for every output record.
     */
    private Text outputValue;

    /**
     * Reused buffer to concatenate the output string; it is emptied after every written record.
     */
    private final StringBuilder sb = new StringBuilder(2000);

    /**
     * To format double value.
     */
    private final DecimalFormat df = new DecimalFormat("##.######");

    /**
     * If true, missing values are excluded from the count used for mean, standard deviation, skewness and kurtosis.
     */
    private boolean statsExcludeMissingValue;

    /**
     * Model Config read from HDFS
     */
    private ModelConfig modelConfig;

    /**
     * Load all configurations for modelConfig and columnConfigList from source type.
     *
     * @param context
     *            the reducer context whose configuration holds the source type and the two config locations
     * @throws RuntimeException
     *             wrapping the {@link IOException} if either configuration cannot be loaded
     */
    private void loadConfigFiles(final Context context) {
        try {
            SourceType sourceType = SourceType.valueOf(context.getConfiguration().get(
                    Constants.SHIFU_MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
            this.modelConfig = CommonUtils.loadModelConfig(
                    context.getConfiguration().get(Constants.SHIFU_MODEL_CONFIG), sourceType);
            this.columnConfigList = CommonUtils.loadColumnConfigList(
                    context.getConfiguration().get(Constants.SHIFU_COLUMN_CONFIG), sourceType);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Do initialization like ModelConfig and ColumnConfig loading.
     *
     * <p>
     * Also reads the "exclude missing values" flag (default {@code true}) and creates the reusable output value.
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        loadConfigFiles(context);
        this.statsExcludeMissingValue = context.getConfiguration().getBoolean(Constants.SHIFU_STATS_EXLCUDE_MISSING,
                true);
        this.outputValue = new Text();
    }

    /**
     * Merge all partial binning statistics of one column and write a single delimited record for it.
     *
     * <p>
     * Counts, sums, bin counts, bin weights, frequent items and the HyperLogLogPlus sketch of all incoming values are
     * accumulated. For categorical columns the categories are optionally re-binned and min/max/sums are recomputed
     * from the per-bin rates. Columns with an invalid number of bins are logged and skipped without output.
     *
     * @param key
     *            index of the column in the column config list
     * @param values
     *            partial binning statistics of that column
     * @param context
     *            the reducer context the output record is written to
     * @throws IllegalStateException
     *             if the column type in the column config has no matching binning info in {@code values}
     */
    @Override
    protected void reduce(IntWritable key, Iterable<BinningInfoWritable> values, Context context) throws IOException,
            InterruptedException {
        long start = System.currentTimeMillis();

        double sum = 0d;
        double squaredSum = 0d;
        double tripleSum = 0d;
        double quarticSum = 0d;
        long count = 0L, missingCount = 0L;
        // Double.MIN_VALUE is the smallest POSITIVE double, so it must not be used as the start value for max:
        // a column with only negative values would otherwise never update it.
        double min = Double.MAX_VALUE, max = -Double.MAX_VALUE;

        List<Double> binBoundaryList = null;
        List<String> binCategories = null;
        long[] binCountPos = null;
        long[] binCountNeg = null;
        double[] binWeightPos = null;
        double[] binWeightNeg = null;

        ColumnConfig columnConfig = this.columnConfigList.get(key.get());

        HyperLogLogPlus hyperLogLogPlus = null;
        Set<String> fis = new HashSet<String>();
        long totalCount = 0, invalidCount = 0, validNumCount = 0;
        int binSize = 0;

        for(BinningInfoWritable info: values) {
            CountAndFrequentItemsWritable cfiw = info.getCfiw();
            totalCount += cfiw.getCount();
            invalidCount += cfiw.getInvalidCount();
            validNumCount += cfiw.getValidNumCount();
            fis.addAll(cfiw.getFrequetItems());

            // the first sketch is taken as is, every following one is merged into it
            if(hyperLogLogPlus == null) {
                hyperLogLogPlus = HyperLogLogPlus.Builder.build(cfiw.getHyperBytes());
            } else {
                try {
                    hyperLogLogPlus = (HyperLogLogPlus) hyperLogLogPlus.merge(HyperLogLogPlus.Builder.build(cfiw
                            .getHyperBytes()));
                } catch (CardinalityMergeException e) {
                    throw new RuntimeException(e);
                }
            }

            // the accumulators are sized from the first info seen; they hold one slot more than the number of bins
            if(info.isNumeric() && binBoundaryList == null) {
                binBoundaryList = info.getBinBoundaries();
                binSize = binBoundaryList.size();
                binCountPos = new long[binSize + 1];
                binCountNeg = new long[binSize + 1];
                binWeightPos = new double[binSize + 1];
                binWeightNeg = new double[binSize + 1];
            }
            if(!info.isNumeric() && binCategories == null) {
                binCategories = info.getBinCategories();
                binSize = binCategories.size();
                binCountPos = new long[binSize + 1];
                binCountNeg = new long[binSize + 1];
                binWeightPos = new double[binSize + 1];
                binWeightNeg = new double[binSize + 1];
            }

            count += info.getTotalCount();
            missingCount += info.getMissingCount();
            // for numeric, such sums are OK, for categorical, such values are all 0, should be updated by using
            // binCountPos and binCountNeg
            sum += info.getSum();
            squaredSum += info.getSquaredSum();
            tripleSum += info.getTripleSum();
            quarticSum += info.getQuarticSum();
            if(Double.compare(max, info.getMax()) < 0) {
                max = info.getMax();
            }
            if(Double.compare(min, info.getMin()) > 0) {
                min = info.getMin();
            }

            for(int i = 0; i < (binSize + 1); i++) {
                binCountPos[i] += info.getBinCountPos()[i];
                binCountNeg[i] += info.getBinCountNeg()[i];
                binWeightPos[i] += info.getBinWeightPos()[i];
                binWeightNeg[i] += info.getBinWeightNeg()[i];
            }
        }

        // Fail with a clear message instead of a bare NullPointerException further down when the column type in
        // the column config does not match the kind of binning info that was received.
        if(columnConfig.isCategorical()) {
            if(binCategories == null) {
                throw new IllegalStateException("Column " + key.get() + " " + columnConfig.getColumnName()
                        + " is categorical but no categorical binning info was received.");
            }
        } else if(binBoundaryList == null) {
            throw new IllegalStateException("Column " + key.get() + " " + columnConfig.getColumnName()
                    + " is numeric but no numeric binning info was received.");
        }

        // To merge categorical binning
        if(columnConfig.isCategorical() && modelConfig.getStats().getCateMaxNumBin() > 0) {
            CateBinningStats cateBinningStats = rebinCategoricalValues(new CateBinningStats(binCategories,
                    binCountPos, binCountNeg, binWeightPos, binWeightNeg));
            LOG.info("For variable - {}, {} bins is rebined to {} bins",
                    columnConfig.getColumnName(), binCategories.size(), cateBinningStats.binCategories.size());

            binCategories = cateBinningStats.binCategories;
            binCountPos = cateBinningStats.binCountPos;
            binCountNeg = cateBinningStats.binCountNeg;
            binWeightPos = cateBinningStats.binWeightPos;
            binWeightNeg = cateBinningStats.binWeightNeg;
        }

        double[] binPosRate;
        if(modelConfig.isRegression()) {
            binPosRate = computePosRate(binCountPos, binCountNeg);
        } else {
            // for multiple classfication, use rate of categories to compute a value
            binPosRate = computeRateForMultiClassfication(binCountPos);
        }

        String binBounString = null;

        if(columnConfig.isCategorical()) {
            // a list size can never be negative, so only the upper limit has to be checked
            if(binCategories.size() > MAX_CATEGORICAL_BINC_COUNT) {
                LOG.warn("Column {} {} with invalid bin category size {}.", key.get(),
                        columnConfig.getColumnName(), binCategories.size());
                return;
            }
            binBounString = Base64Utils.base64Encode("["
                    + StringUtils.join(binCategories, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR) + "]");

            // recompute such values for categorical variables from the per-bin rates; all four sums are reset so
            // that nothing accumulated above leaks into the recomputed moments
            min = Double.MAX_VALUE;
            max = -Double.MAX_VALUE;
            sum = 0d;
            squaredSum = 0d;
            tripleSum = 0d;
            quarticSum = 0d;
            for(int i = 0; i < binPosRate.length; i++) {
                if(!Double.isNaN(binPosRate[i])) {
                    if(Double.compare(max, binPosRate[i]) < 0) {
                        max = binPosRate[i];
                    }

                    if(Double.compare(min, binPosRate[i]) > 0) {
                        min = binPosRate[i];
                    }
                    // every record of the bin contributes the rate of its bin as value
                    long binCount = binCountPos[i] + binCountNeg[i];
                    sum += binPosRate[i] * binCount;
                    double squaredVal = binPosRate[i] * binPosRate[i];
                    squaredSum += squaredVal * binCount;
                    tripleSum += squaredVal * binPosRate[i] * binCount;
                    quarticSum += squaredVal * squaredVal * binCount;
                }
            }
        } else {
            if(binBoundaryList.isEmpty()) {
                LOG.warn("Column {} {} with invalid bin boundary size {}.", key.get(),
                        columnConfig.getColumnName(), binBoundaryList.size());
                return;
            }
            binBounString = binBoundaryList.toString();
        }

        // column metrics are only computed for regression; otherwise empty / zero placeholders are written
        ColumnMetrics columnCountMetrics = null;
        ColumnMetrics columnWeightMetrics = null;
        if(modelConfig.isRegression()) {
            columnCountMetrics = ColumnStatsCalculator.calculateColumnMetrics(binCountNeg, binCountPos);
            columnWeightMetrics = ColumnStatsCalculator.calculateColumnMetrics(binWeightNeg, binWeightPos);
        }

        // To make it be consistent with SPDT, missingCount is excluded to compute mean, stddev ...
        long realCount = this.statsExcludeMissingValue ? (count - missingCount) : count;

        double mean = sum / realCount;
        double stdDev = Math.sqrt(Math.abs((squaredSum - (sum * sum) / realCount + EPS) / (realCount - 1)));
        double aStdDev = Math.sqrt(Math.abs((squaredSum - (sum * sum) / realCount + EPS) / realCount));

        double skewness = ColumnStatsCalculator.computeSkewness(realCount, mean, aStdDev, sum, squaredSum, tripleSum);
        double kurtosis = ColumnStatsCalculator.computeKurtosis(realCount, mean, aStdDev, sum, squaredSum, tripleSum,
                quarticSum);

        // The placeholder WOE arrays are sized from the (possibly re-binned) count array so that their length
        // always matches the bin arrays written in the same record.
        int binArrayLength = binCountPos.length;

        sb.append(key.get())
                .append(Constants.DEFAULT_DELIMITER)
                .append(binBounString)
                .append(Constants.DEFAULT_DELIMITER)
                .append(Arrays.toString(binCountNeg))
                .append(Constants.DEFAULT_DELIMITER)
                .append(Arrays.toString(binCountPos))
                .append(Constants.DEFAULT_DELIMITER)
                // this position always carries an empty array
                .append(Arrays.toString(new double[0]))
                .append(Constants.DEFAULT_DELIMITER)
                .append(Arrays.toString(binPosRate))
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnCountMetrics == null ? "" : df.format(columnCountMetrics.getKs()))
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnWeightMetrics == null ? "" : df.format(columnWeightMetrics.getIv()))
                .append(Constants.DEFAULT_DELIMITER)
                .append(df.format(max))
                .append(Constants.DEFAULT_DELIMITER)
                .append(df.format(min))
                .append(Constants.DEFAULT_DELIMITER)
                .append(df.format(mean))
                .append(Constants.DEFAULT_DELIMITER)
                .append(df.format(stdDev))
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnConfig.isCategorical() ? "C" : "N")
                .append(Constants.DEFAULT_DELIMITER)
                // the mean is written a second time at this position
                .append(df.format(mean))
                .append(Constants.DEFAULT_DELIMITER)
                .append(missingCount)
                .append(Constants.DEFAULT_DELIMITER)
                .append(count)
                .append(Constants.DEFAULT_DELIMITER)
                .append(missingCount * 1.0d / count)
                .append(Constants.DEFAULT_DELIMITER)
                .append(Arrays.toString(binWeightNeg))
                .append(Constants.DEFAULT_DELIMITER)
                .append(Arrays.toString(binWeightPos))
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnCountMetrics == null ? "" : columnCountMetrics.getWoe())
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnWeightMetrics == null ? "" : columnWeightMetrics.getWoe())
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnWeightMetrics == null ? "" : columnWeightMetrics.getKs())
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnCountMetrics == null ? "" : columnCountMetrics.getIv())
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnCountMetrics == null ? Arrays.toString(new double[binArrayLength])
                        : columnCountMetrics.getBinningWoe().toString())
                .append(Constants.DEFAULT_DELIMITER)
                .append(columnWeightMetrics == null ? Arrays.toString(new double[binArrayLength])
                        : columnWeightMetrics.getBinningWoe().toString())
                .append(Constants.DEFAULT_DELIMITER)
                .append(skewness)
                .append(Constants.DEFAULT_DELIMITER)
                .append(kurtosis)
                .append(Constants.DEFAULT_DELIMITER)
                .append(totalCount)
                .append(Constants.DEFAULT_DELIMITER)
                .append(invalidCount)
                .append(Constants.DEFAULT_DELIMITER)
                .append(validNumCount)
                .append(Constants.DEFAULT_DELIMITER)
                .append(hyperLogLogPlus.cardinality())
                .append(Constants.DEFAULT_DELIMITER)
                .append(limitedFrequentItems(fis));

        outputValue.set(sb.toString());
        context.write(NullWritable.get(), outputValue);
        // empty the shared buffer for the next column
        sb.delete(0, sb.length());
        LOG.debug("Time:{}", (System.currentTimeMillis() - start));
    }

    /**
     * Join at most {@code FREQUET_ITEM_MAX_SIZE * 10} frequent items into one comma separated string.
     *
     * <p>
     * Inside every item the output delimiter and commas are replaced by a blank so that the item can neither break
     * the output record nor the comma separated list.
     *
     * @param fis
     *            the frequent items; which items are kept when the set is larger than the limit depends on the
     *            iteration order of the set
     * @return the joined items, an empty string for an empty set
     */
    private static String limitedFrequentItems(Set<String> fis) {
        StringBuilder sb = new StringBuilder(200);
        int size = Math.min(fis.size(), CountAndFrequentItemsWritable.FREQUET_ITEM_MAX_SIZE * 10);
        Iterator<String> iterator = fis.iterator();
        for(int i = 0; i < size; i++) {
            if(i > 0) {
                sb.append(",");
            }
            sb.append(iterator.next().replaceAll("\\" + Constants.DEFAULT_DELIMITER, " ").replace(",", " "));
        }
        return sb.toString();
    }

    /**
     * Compute the positive rate {@code pos / (pos + neg)} of every bin.
     *
     * @param binCountPos
     *            positive count per bin
     * @param binCountNeg
     *            negative count per bin, same length as {@code binCountPos}
     * @return the positive rate per bin; bins without any record keep the rate 0
     */
    private double[] computePosRate(long[] binCountPos, long[] binCountNeg) {
        assert binCountPos != null && binCountNeg != null && binCountPos.length == binCountNeg.length;
        double[] posRate = new double[binCountPos.length];
        for(int i = 0; i < posRate.length; i++) {
            long binTotal = binCountPos[i] + binCountNeg[i];
            // only compute effective pos rate, if /0, don't do it
            if(binTotal != 0L) {
                posRate[i] = binCountPos[i] * 1.0d / binTotal;
            }
        }
        return posRate;
    }

    /**
     * Compute the share of every bin in the total count of all bins.
     *
     * @param binCount
     *            count per bin
     * @return {@code binCount[i] / sum(binCount)} per bin; all 0 if the total count is 0
     */
    private double[] computeRateForMultiClassfication(long[] binCount) {
        double[] rate = new double[binCount.length];
        double sum = 0d;
        for(int i = 0; i < binCount.length; i++) {
            sum += binCount[i];
        }
        // nothing to divide by, leave every rate at 0
        if(Double.compare(sum, 0d) == 0) {
            return rate;
        }
        for(int i = 0; i < binCount.length; i++) {
            rate[i] = binCount[i] * 1.0d / sum;
        }
        return rate;
    }

    /**
     * Merge the categories of a categorical column into fewer bins.
     *
     * <p>
     * Every category becomes one {@link CategoricalBinInfo}; the infos are sorted and merged by a
     * {@link CateDynamicBinning} configured with the maximal categorical bin number of the model config. The values
     * of a merged bin are joined with '^' to form its new category name. The last slot of every input array
     * (presumably the bin of missing values - it has no category of its own) is carried over unchanged into the last
     * slot of the corresponding result array.
     *
     * @param cateBinStats
     *            categories with their counts and weights; every array has one slot more than there are categories
     * @return the merged categories with their counts and weights, in the same layout as the input
     */
    public CateBinningStats rebinCategoricalValues(CateBinningStats cateBinStats) {
        List<CategoricalBinInfo> categoricalBinInfos = new ArrayList<CategoricalBinInfo>();
        for(int i = 0; i < cateBinStats.binCategories.size(); i++) {
            String cval = cateBinStats.binCategories.get(i);
            CategoricalBinInfo binInfo = new CategoricalBinInfo();
            List<String> vals = new ArrayList<String>();
            vals.add(cval);
            binInfo.setValues(vals);
            binInfo.setPositiveCnt(cateBinStats.binCountPos[i]);
            binInfo.setNegativeCnt(cateBinStats.binCountNeg[i]);
            binInfo.setWeightPos(cateBinStats.binWeightPos[i]);
            binInfo.setWeightNeg(cateBinStats.binWeightNeg[i]);
            categoricalBinInfos.add(binInfo);
        }

        Collections.sort(categoricalBinInfos);
        CateDynamicBinning binInst = new CateDynamicBinning(modelConfig.getStats().getCateMaxNumBin());
        List<CategoricalBinInfo> mergedBinInfos = binInst.merge(categoricalBinInfos);

        int mergedSize = mergedBinInfos.size();
        List<String> binCategories = new ArrayList<String>();
        long[] binCountPos = new long[mergedSize + 1];
        long[] binCountNeg = new long[mergedSize + 1];
        double[] binWeightPos = new double[mergedSize + 1];
        double[] binWeightNeg = new double[mergedSize + 1];
        for(int i = 0; i < mergedSize; i++) {
            CategoricalBinInfo binInfo = mergedBinInfos.get(i);
            binCategories.add(StringUtils.join(binInfo.getValues(), '^'));
            binCountPos[i] = binInfo.getPositiveCnt();
            binCountNeg[i] = binInfo.getNegativeCnt();
            binWeightPos[i] = binInfo.getWeightPos();
            binWeightNeg[i] = binInfo.getWeightNeg();
        }

        // the extra last slot is not part of the merge and is copied as is
        binCountPos[mergedSize] = cateBinStats.binCountPos[cateBinStats.binCountPos.length - 1];
        binCountNeg[mergedSize] = cateBinStats.binCountNeg[cateBinStats.binCountNeg.length - 1];
        binWeightPos[mergedSize] = cateBinStats.binWeightPos[cateBinStats.binWeightPos.length - 1];
        binWeightNeg[mergedSize] = cateBinStats.binWeightNeg[cateBinStats.binWeightNeg.length - 1];

        return new CateBinningStats(binCategories, binCountPos, binCountNeg, binWeightPos, binWeightNeg);
    }

    /**
     * Plain holder for the categories of a categorical column together with the count and weight arrays per bin.
     * The given list and arrays are stored as is, without copying.
     */
    public static class CateBinningStats {
        List<String> binCategories = null;
        long[] binCountPos = null;
        long[] binCountNeg = null;
        double[] binWeightPos = null;
        double[] binWeightNeg = null;

        /**
         * @param binCategories
         *            category names
         * @param binCountPos
         *            positive count per bin
         * @param binCountNeg
         *            negative count per bin
         * @param binWeightPos
         *            positive weight per bin
         * @param binWeightNeg
         *            negative weight per bin
         */
        public CateBinningStats(List<String> binCategories, long[] binCountPos, long[] binCountNeg,
                double[] binWeightPos, double[] binWeightNeg) {
            this.binCategories = binCategories;
            this.binCountPos = binCountPos;
            this.binCountNeg = binCountNeg;
            this.binWeightPos = binWeightPos;
            this.binWeightNeg = binWeightNeg;
        }
    }
}