/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import ml.shifu.shifu.container.ValueObject;
import ml.shifu.shifu.container.ValueObject.ValueObjectComparator;
import ml.shifu.shifu.container.obj.ModelStatsConf.BinningMethod;
import ml.shifu.shifu.util.QuickSort;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Serializable;
import java.util.*;
/**
 * Binning puts a list of {@link ValueObject}s into bins and collects per-bin statistics (negative/positive
 * counts, weighted counts, positive rate).
 * <p>
 * Numerical data is split by {@link BinningMethod#EqualPositive} or {@link BinningMethod#EqualTotal};
 * categorical data gets one bin per distinct raw value, ordered by ascending positive rate.
 */
public class Binning {

    /**
     * logger
     */
    private static final Logger log = LoggerFactory.getLogger(Binning.class);

    /**
     * epsilon, two numerical values whose absolute difference is not larger than this are treated as the same
     * value when bins are merged
     */
    private static final double EPS = 1e-5;

    /**
     * auto type threshold, it helps to judge if it's categorical or numerical
     */
    private Integer autoTypeThreshold = 5;

    /**
     * merge flag, when enabled a new bin is only opened on a value that differs from the previous one
     */
    private Boolean mergeEnabled = true;

    /**
     * Data type, Numerical/Categorical/Auto
     */
    public static enum BinningDataType {
        Numerical, Categorical, Auto
    }

    /**
     * positive tags
     */
    private List<String> posTags;

    /**
     * negative tags
     */
    private List<String> negTags;

    /**
     * data type, flag
     */
    private BinningDataType dataType;

    /**
     * Input Data
     */
    private List<ValueObject> voList;

    /**
     * value object size, captured when the instance is constructed
     */
    private Integer voSize;

    /**
     * Negative Count(Good txn)
     */
    private List<Integer> binCountNeg;

    /**
     * Positive Count(Bad txn)
     */
    private List<Integer> binCountPos;

    /**
     * Sum of weights of the negative records per bin
     */
    private List<Double> binWeightedNeg;

    /**
     * Sum of weights of the positive records per bin
     */
    private List<Double> binWeightedPos;

    /**
     * Bin Boundary for Numerical Variables
     */
    private List<Double> binBoundary = null;

    /**
     * Bin Category for Categorical Variables
     */
    private List<String> binCategory = null;

    /**
     * Bin Average Score, filled with zeros here
     */
    private List<Integer> binAvgScore = null;

    /**
     * Bin positive rate
     */
    private List<Double> binPosCaseRate = null;

    /**
     * default numBins
     */
    private int expectNumBins = 10;

    /**
     * actual bins, -1 until a binning has been done
     */
    private int actualNumBins = -1;

    /**
     * default binningMethod
     */
    private BinningMethod binningMethod = BinningMethod.EqualPositive;

    /**
     * Constructor
     * 
     * @param posTags
     *            The positive tags list, identify the positive tag in voList
     * @param negTags
     *            The negative tags list, identify the negative tag in volist
     * @param type
     *            The data type
     * @param voList
     *            Value object list
     */
    public Binning(List<String> posTags, List<String> negTags, BinningDataType type, List<ValueObject> voList) {
        this.posTags = posTags;
        this.negTags = negTags;
        this.dataType = type;
        this.voList = voList;
        this.voSize = voList.size();

        this.binCountNeg = new ArrayList<Integer>();
        this.binCountPos = new ArrayList<Integer>();
        this.binBoundary = new ArrayList<Double>();
        this.binCategory = new ArrayList<String>();
        this.binAvgScore = new ArrayList<Integer>();
        this.binPosCaseRate = new ArrayList<Double>();
        this.binWeightedNeg = new ArrayList<Double>();
        this.binWeightedPos = new ArrayList<Double>();
    }

    /**
     * @return the weighted negative sum of each bin
     */
    public List<Double> getBinWeightedNeg() {
        return binWeightedNeg;
    }

    /**
     * @param binWeightedNeg
     *            the weighted negative sum of each bin
     */
    public void setBinWeightedNeg(List<Double> binWeightedNeg) {
        this.binWeightedNeg = binWeightedNeg;
    }

    /**
     * @return the weighted positive sum of each bin
     */
    public List<Double> getBinWeightedPos() {
        return binWeightedPos;
    }

    /**
     * @param binWeightedPos
     *            the weighted positive sum of each bin
     */
    public void setBinWeightedPos(List<Double> binWeightedPos) {
        this.binWeightedPos = binWeightedPos;
    }

    /**
     * setter, the max bins
     * 
     * @param numBins
     *            the numBins
     */
    public void setMaxNumOfBins(int numBins) {
        this.expectNumBins = numBins;
    }

    /**
     * setter, the binning method
     * 
     * @param binningMethod
     *            the binningMethod
     */
    public void setBinningMethod(BinningMethod binningMethod) {
        this.binningMethod = binningMethod;
    }

    /**
     * Set the max size of auto type threshold
     * 
     * @param autoTypeThreshold
     *            the autoTypeThreshold
     */
    public void setAutoTypeThreshold(Integer autoTypeThreshold) {
        this.autoTypeThreshold = autoTypeThreshold;
    }

    /**
     * Numerical: Raw to Value Conversion happens before Binning
     * <p>
     * Categorical: Raw to Value Conversion happens after Binning
     * <p>
     * Start binning method. When the data type is Auto, it is resolved to Categorical or Numerical first.
     */
    public void doBinning() {
        if(dataType.equals(BinningDataType.Auto)) {
            resolveAutoDataType();
        }

        if(dataType.equals(BinningDataType.Categorical)) {
            doCategoricalBinning();
        } else if(dataType.equals(BinningDataType.Numerical)) {
            doNumericalBinning();
        }
    }

    /**
     * Resolve the Auto data type. The data is treated as Categorical when at least one record has no numerical
     * value, or when the number of different keys is not larger than the auto type threshold; otherwise it is
     * treated as Numerical.
     */
    private void resolveAutoDataType() {
        int cntRaw = 0;
        int cntValue = 0;
        Set<Object> keySet = new HashSet<Object>();

        for(ValueObject vo: voList) {
            boolean hasValue = vo.getValue() != null;
            if(hasValue) {
                cntValue++;
            } else {
                cntRaw++;
            }
            // stop collecting keys as soon as the threshold is exceeded, the exact number is not needed
            if(keySet.size() <= this.autoTypeThreshold) {
                if(hasValue) {
                    keySet.add(vo.getValue());
                } else {
                    keySet.add(vo.getRaw());
                }
            }
        }

        if(cntRaw > 0 || keySet.size() <= this.autoTypeThreshold) {
            this.dataType = BinningDataType.Categorical;
            if(cntValue > 0) {
                // categorical binning works on the raw field, so copy numerical values over
                for(ValueObject vo: voList) {
                    // a record may have neither raw nor value, leave its raw as null in that case
                    if(vo.getRaw() == null && vo.getValue() != null) {
                        vo.setRaw(vo.getValue().toString());
                    }
                }
            }
        } else {
            this.dataType = BinningDataType.Numerical;
        }
    }

    /**
     * Start numerical binning: sort the data in place, then bin it with the configured binning method. Nothing
     * is binned when the method is neither EqualPositive nor EqualTotal.
     * <p>
     * BinBoundary: left, inclusive
     */
    private void doNumericalBinning() {
        log.debug("==> There are {} to sort in Binning.", voList.size());
        long timestamp = System.currentTimeMillis();

        // use our in-place quick order
        QuickSort.sort(voList, new ValueObjectComparator(BinningDataType.Numerical));

        log.debug("==> Spend {} milli-seconds to sort data.", System.currentTimeMillis() - timestamp);

        if(BinningMethod.EqualPositive.equals(binningMethod)) {
            doEqualPositiveBinning();
        } else if(BinningMethod.EqualTotal.equals(binningMethod)) {
            doEqualTotalBinning();
        }
    }

    /**
     * Positive rate of a bin.
     * 
     * @param posCount
     *            number of positive records in the bin
     * @param negCount
     *            number of negative records in the bin
     * @return posCount / (posCount + negCount), or 0 for an empty bin
     */
    private static double positiveRate(int posCount, int negCount) {
        int total = posCount + negCount;
        return total == 0 ? 0.0 : (double) posCount / total;
    }

    /**
     * equal bad binning: every bin is closed once it holds at least ceil(#positive / expectNumBins) positive
     * records. An empty input produces zero bins and only the negative infinity boundary.
     */
    private void doEqualPositiveBinning() {
        if(voList.isEmpty()) {
            this.binBoundary.add(Double.NEGATIVE_INFINITY);
            this.actualNumBins = 0;
            return;
        }

        int sumBad = 0;
        for(int i = 0; i < voSize; i++) {
            if(posTags.contains(voList.get(i).getTag())) {
                sumBad++;
            }
        }
        int binSize = (int) Math.ceil((double) sumBad / (double) expectNumBins);

        int[] countNeg = new int[expectNumBins];
        int[] countPos = new int[expectNumBins];
        double[] countWeightedNeg = new double[expectNumBins];
        double[] countWeightedPos = new double[expectNumBins];

        // add first bin (from negative infinite)
        this.binBoundary.add(Double.NEGATIVE_INFINITY);

        int currBin = 0;
        int lastIndex = voList.size() - 1;
        double prevData = voList.get(0).getValue();

        for(int i = 0; i < voSize; i++) {
            ValueObject vo = voList.get(i);
            double currData = vo.getValue();

            // current bin is full
            if(countPos[currBin] >= binSize) {
                // NOTE(review): once the last bin is full, all records but the last one are skipped without
                // being counted in any bin - confirm that dropping this leftover is intended
                if(currBin == expectNumBins - 1 && i != lastIndex) {
                    continue;
                }

                // when merge is enabled, only open a new bin on a value different from the previous one
                if(i == 0 || !mergeEnabled || Math.abs(currData - prevData) > EPS) {
                    // NOTE(review): the last record never opens a bin of its own, it is left uncounted
                    if(i == lastIndex) {
                        break;
                    }
                    currBin++;
                    this.binBoundary.add(currData);
                }
            }

            // increment the counter of the current bin, everything which is not negative counts as positive
            if(negTags.contains(vo.getTag())) {
                countNeg[currBin]++;
                countWeightedNeg[currBin] += vo.getWeight();
            } else {
                countPos[currBin]++;
                countWeightedPos[currBin] += vo.getWeight();
            }

            prevData = currData;
        }

        this.actualNumBins = currBin + 1;
        for(int i = 0; i < this.actualNumBins; i++) {
            binCountNeg.add(countNeg[i]);
            binCountPos.add(countPos[i]);
            binAvgScore.add(0);
            binPosCaseRate.add(positiveRate(countPos[i], countNeg[i]));
            this.binWeightedNeg.add(countWeightedNeg[i]);
            this.binWeightedPos.add(countWeightedPos[i]);
        }
    }

    /**
     * equal total binning: bin k is closed once the share of processed valid records reaches
     * (k + 1) / expectNumBins. Only records tagged as positive or negative are valid; all other records are
     * skipped and neither close a bin nor open a boundary.
     */
    private void doEqualTotalBinning() {
        int bin = 0;
        int cntPos = 0;
        int cntNeg = 0;
        double cntWeightedPos = 0.0;
        double cntWeightedNeg = 0.0;
        boolean isFull = false;

        // Add initial bin left boundary: -infinity
        binBoundary.add(Double.NEGATIVE_INFINITY);

        int cntValidValue = 0;
        for(ValueObject vo: voList) {
            if(posTags.contains(vo.getTag()) || negTags.contains(vo.getTag())) {
                cntValidValue += 1;
            }
        }

        int cntCumTotal = 0;
        for(ValueObject vo: voList) {
            Object tag = vo.getTag();
            boolean isPos = posTags.contains(tag);
            if(!isPos && !negTags.contains(tag)) {
                // not a valid record, skip
                continue;
            }

            // Pre-processing: the previous bin is full, this record starts the next one
            if(isFull) {
                binBoundary.add(vo.getValue());
                isFull = false;
            }

            // Core: push into bin
            if(isPos) {
                cntPos++;
                cntWeightedPos += vo.getWeight();
            } else {
                cntNeg++;
                cntWeightedNeg += vo.getWeight();
            }
            cntCumTotal += 1;

            // Post-processing: if bin is full, update related fields
            if((double) cntCumTotal / (double) cntValidValue >= (double) (bin + 1) / (double) expectNumBins) {
                isFull = true;
                binCountPos.add(cntPos);
                binCountNeg.add(cntNeg);
                binAvgScore.add(0);
                binWeightedNeg.add(cntWeightedNeg);
                binWeightedPos.add(cntWeightedPos);
                binPosCaseRate.add(positiveRate(cntPos, cntNeg));
                bin++;

                cntPos = 0;
                cntNeg = 0;
                cntWeightedNeg = 0.0;
                cntWeightedPos = 0.0;
            }
        }

        this.actualNumBins = bin;
    }

    /**
     * Increase the counter of the key by 1, a missing key starts from 0
     * 
     * @param map
     *            the counter map
     * @param key
     *            the key to increase
     */
    private void incMapCnt(Map<String, Integer> map, String key) {
        Integer cnt = map.get(key);
        map.put(key, cnt == null ? 1 : cnt + 1);
    }

    /**
     * Add the value to the sum of the key, a missing key starts from 0.0
     * 
     * @param map
     *            the sum map
     * @param key
     *            the key to add to
     * @param value
     *            the value to add
     */
    private void incMapWithValue(Map<String, Double> map, String key, Double value) {
        Double num = map.get(key);
        map.put(key, (num == null ? 0.0 : num) + value);
    }

    /**
     * @return the count stored for the key, 0 if the key is missing
     */
    private static int countOf(Map<String, Integer> hist, String key) {
        Integer cnt = hist.get(key);
        return cnt == null ? 0 : cnt;
    }

    /**
     * @return the weight sum stored for the key, 0.0 if the key is missing
     */
    private static double weightOf(Map<String, Double> weights, String key) {
        Double weight = weights.get(key);
        return weight == null ? 0.0 : weight;
    }

    /**
     * categorical binning: one bin per distinct raw value, bins are ordered by ascending positive rate (ties
     * are ordered by category name). Afterwards the value of every record is set to the positive rate of its
     * category.
     */
    private void doCategoricalBinning() {
        // For categorical variable, it's not necessary to sort the data, only the categories are sorted
        Map<String, Integer> categoryHistNeg = new HashMap<String, Integer>();
        Map<String, Integer> categoryHistPos = new HashMap<String, Integer>();
        Map<String, Double> categoryWeightedNeg = new HashMap<String, Double>();
        Map<String, Double> categoryWeightedPos = new HashMap<String, Double>();
        Set<String> categorySet = new HashSet<String>();

        for(int i = 0; i < voSize; i++) {
            ValueObject vo = voList.get(i);
            String category = vo.getRaw();
            categorySet.add(category);

            // everything which is not negative counts as positive
            if(negTags.contains(vo.getTag())) {
                incMapCnt(categoryHistNeg, category);
                incMapWithValue(categoryWeightedNeg, category, vo.getWeight());
            } else {
                incMapCnt(categoryHistPos, category);
                incMapWithValue(categoryWeightedPos, category, vo.getWeight());
            }
        }

        Map<String, Double> categoryFraudRateMap = new HashMap<String, Double>();
        for(String key: categorySet) {
            categoryFraudRateMap.put(key, positiveRate(countOf(categoryHistPos, key), countOf(categoryHistNeg, key)));
        }

        // Sort the categories by rate. A sorted list keeps every category; a TreeMap keyed by a rate-only
        // comparator would silently drop categories which share the same rate.
        List<String> sortedCategories = new ArrayList<String>(categorySet);
        Collections.sort(sortedCategories, new MapComparator(categoryFraudRateMap));

        for(String key: sortedCategories) {
            binCountNeg.add(countOf(categoryHistNeg, key));
            binCountPos.add(countOf(categoryHistPos, key));
            this.binWeightedNeg.add(weightOf(categoryWeightedNeg, key));
            this.binWeightedPos.add(weightOf(categoryWeightedPos, key));
            // use zero, the average score is calculate in post-process
            binAvgScore.add(0);
            binCategory.add(key);
            binPosCaseRate.add(categoryFraudRateMap.get(key));
        }

        this.actualNumBins = binCategory.size();

        for(ValueObject vo: voList) {
            Double rate = categoryFraudRateMap.get(vo.getRaw());
            // a category without a bin falls back to 0.0
            vo.setValue(rate == null ? 0.0 : rate);
        }
    }

    /**
     * getter the number of bins
     * 
     * @return the actual number of bins, -1 if no binning has been done yet
     */
    public int getNumBins() {
        return this.actualNumBins;
    }

    /**
     * get the bin boundary, the first boundary is negative infinite, each boundary is the left (inclusive)
     * border of a bin
     * 
     * @return the bins boundary
     */
    public List<Double> getBinBoundary() {
        return this.binBoundary;
    }

    /**
     * get the bin category
     * 
     * @return bin category
     */
    public List<String> getBinCategory() {
        return this.binCategory;
    }

    /**
     * @return get the negative bin count
     */
    public List<Integer> getBinCountNeg() {
        return binCountNeg;
    }

    /**
     * @return get the positive bin count
     */
    public List<Integer> getBinCountPos() {
        return binCountPos;
    }

    /**
     * @return get the average score list
     */
    public List<Integer> getBinAvgScore() {
        return binAvgScore;
    }

    /**
     * @return get the positive rate for kv
     */
    public List<Double> getBinPosCaseRate() {
        return binPosCaseRate;
    }

    /**
     * @return get the volist, or null (an error is logged) if any value object still has no value
     */
    public List<ValueObject> getUpdatedVoList() {
        for(ValueObject vo: voList) {
            if(vo.getValue() == null) {
                log.error("Not Updated yet.");
                return null;
            }
        }
        return voList;
    }

    /**
     * comparator for map keys: order by the value mapped in the base map, keys with an equal value are ordered
     * by the key itself (null key first) so that different keys never compare as equal
     */
    private static class MapComparator implements Comparator<String>, Serializable {

        private static final long serialVersionUID = 2178035954425107063L;

        private final Map<String, Double> base;

        public MapComparator(Map<String, Double> base) {
            this.base = base;
        }

        @Override
        public int compare(String a, String b) {
            int byValue = base.get(a).compareTo(base.get(b));
            if(byValue != 0) {
                return byValue;
            }
            // tie-break on the key
            if(a == null) {
                return b == null ? 0 : -1;
            }
            if(b == null) {
                return 1;
            }
            return a.compareTo(b);
        }
    }

    /**
     * @return get the data type, Auto is replaced by the resolved type once the binning is done
     */
    public BinningDataType getUpdatedDataType() {
        return dataType;
    }

    /**
     * set the merge flag
     * 
     * @param mergeEnabled
     *            the merge flag
     */
    public void setMergeEnabled(Boolean mergeEnabled) {
        this.mergeEnabled = mergeEnabled;
    }
}