package ml.shifu.shifu.core.binning;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Created by zhanhu on 4/18/17.
 *
 * Shrinks a list of categorical bins down to a requested number of bins by repeatedly merging the
 * pair of adjacent bins whose merge changes the summed information value of the list the least.
 */
public class CateDynamicBinning {

    /** Smoothing term added to the positive/negative counts so log2 is never evaluated at zero. */
    private static final double EPS = 1e-6;

    /** Position returned by {@link #getBestMergeNode} when no adjacent pair can be merged. */
    private static final int NO_MERGE = 0;

    /** Upper bound on the number of bins returned by {@link #merge(List)}. */
    private final int expectedBinningNum;

    /**
     * @param expectedBinNum
     *            maximum number of bins that {@link #merge(List)} should return
     */
    public CateDynamicBinning(int expectedBinNum) {
        this.expectedBinningNum = expectedBinNum;
    }

    /**
     * Merges adjacent bins until at most {@code expectedBinningNum} bins remain, or until no
     * further merge candidate exists.
     *
     * The returned list is a new list, but it holds the same {@code CategoricalBinInfo} objects
     * as the input, and merging calls {@code mergeRight} on those objects.
     * NOTE(review): if {@code mergeRight} updates its receiver in place (it appears to, since its
     * result is not used here), bins owned by the caller are modified - confirm against
     * {@code CategoricalBinInfo}.
     *
     * @param categoricalBinInfos
     *            bins to merge, in the order in which neighbours may be combined; must not be null
     * @return a new list containing the merged bins
     */
    public List<CategoricalBinInfo> merge(List<CategoricalBinInfo> categoricalBinInfos) {
        List<CategoricalBinInfo> mergedBinInfos = new ArrayList<CategoricalBinInfo>(categoricalBinInfos);
        if (mergedBinInfos.size() > this.expectedBinningNum) {
            double totalInstCnt = getTotalInstCount(mergedBinInfos);
            mergedBinInfos = adjustBinInfos(mergedBinInfos, this.expectedBinningNum, totalInstCnt);
        }
        return mergedBinInfos;
    }

    /**
     * Sums the positive and negative counts over all bins.
     *
     * @param categoricalBinInfos
     *            bins to count
     * @return total number of instances across the bins
     */
    private double getTotalInstCount(List<CategoricalBinInfo> categoricalBinInfos) {
        double total = 0.0;
        for (CategoricalBinInfo binInfo : categoricalBinInfos) {
            total += (binInfo.getNegativeCnt() + binInfo.getPositiveCnt());
        }
        return total;
    }

    /**
     * Repeatedly merges the best adjacent pair, in place, until the list is small enough or no
     * merge candidate is left.
     *
     * @param mergedBinInfos
     *            mutable working list; elements are merged and removed in place
     * @param expectedBinningNum
     *            target maximum size of the list
     * @param totalInstCnt
     *            total instance count over all bins, used to weight each bin
     * @return the same list instance that was passed in
     */
    private List<CategoricalBinInfo> adjustBinInfos(List<CategoricalBinInfo> mergedBinInfos,
            int expectedBinningNum, double totalInstCnt) {
        while (mergedBinInfos.size() > expectedBinningNum) {
            int pos = getBestMergeNode(mergedBinInfos, totalInstCnt);
            if (pos <= NO_MERGE) {
                // Fewer than two bins, or no pair produced a comparable score: stop.
                break;
            }
            // Fold the bin at pos into its left neighbour, then drop it from the list.
            mergedBinInfos.get(pos - 1).mergeRight(mergedBinInfos.get(pos));
            mergedBinInfos.remove(pos);
        }
        return mergedBinInfos;
    }

    /**
     * Finds the adjacent pair whose merge yields the smallest change in the summed information
     * value of the list.
     *
     * @param mergedBinInfos
     *            current bins
     * @param totalInstCnt
     *            total instance count over all bins
     * @return index of the right-hand bin of the chosen pair (always >= 1), or {@link #NO_MERGE}
     *         when the list has fewer than two bins or no pair produced a usable score
     */
    private int getBestMergeNode(List<CategoricalBinInfo> mergedBinInfos, double totalInstCnt) {
        Iterator<CategoricalBinInfo> iterator = mergedBinInfos.iterator();
        if (!iterator.hasNext()) {
            return NO_MERGE;
        }

        double entropy = calculateEntropy(mergedBinInfos, totalInstCnt);
        double smallestChange = Double.MAX_VALUE;
        int nodeIndexToMerge = NO_MERGE;

        CategoricalBinInfo left = iterator.next();
        for (int pos = 1; iterator.hasNext(); pos++) {
            CategoricalBinInfo right = iterator.next();

            // Score the merge on a copy so the real bins are untouched while searching.
            CategoricalBinInfo merged = left.clone();
            merged.mergeRight(right);

            double entropyMerging = entropy - getInfoValue(left, totalInstCnt)
                    - getInfoValue(right, totalInstCnt) + getInfoValue(merged, totalInstCnt);
            double change = entropyMerging - entropy;
            // Strict comparison: on ties the left-most pair wins.
            if (change < smallestChange) {
                nodeIndexToMerge = pos;
                smallestChange = change;
            }
            left = right;
        }
        return nodeIndexToMerge;
    }

    /**
     * Sums {@link #getInfoValue} over all bins.
     *
     * @param mergedBinInfos
     *            bins to evaluate
     * @param totalInstCnt
     *            total instance count over all bins
     * @return summed information value of the list
     */
    private double calculateEntropy(List<CategoricalBinInfo> mergedBinInfos, double totalInstCnt) {
        double entropy = 0.0;
        for (CategoricalBinInfo binInfo : mergedBinInfos) {
            entropy += getInfoValue(binInfo, totalInstCnt);
        }
        return entropy;
    }

    /**
     * Computes one bin's contribution: the bin's share of all instances multiplied by the binary
     * entropy (base 2) of its EPS-smoothed positive/negative rates.
     *
     * A bin without instances, or a non-positive overall count, contributes 0. Without this guard
     * the divisions below produce Infinity/NaN, the NaN spreads into the list entropy, every
     * merge score becomes NaN and no bins are ever merged.
     *
     * @param cateBinInfo
     *            bin to evaluate
     * @param totalInstCnt
     *            total instance count over all bins
     * @return the bin's information value, or 0.0 when the bin or the whole data set is empty
     */
    private double getInfoValue(CategoricalBinInfo cateBinInfo, double totalInstCnt) {
        double binInstCnt = cateBinInfo.getTotalInstCnt();
        if (binInstCnt <= 0.0 || totalInstCnt <= 0.0) {
            return 0.0;
        }
        double percent = binInstCnt / totalInstCnt;
        double positiveRate = (cateBinInfo.getPositiveCnt() + EPS) / binInstCnt;
        double negativeRate = (cateBinInfo.getNegativeCnt() + EPS) / binInstCnt;
        return -1 * percent * (positiveRate * log2(positiveRate) + negativeRate * log2(negativeRate));
    }

    /**
     * @param ratio
     *            value to take the logarithm of
     * @return base-2 logarithm of {@code ratio}
     */
    private double log2(double ratio) {
        return Math.log(ratio) / Math.log(2.0d);
    }
}