/*
 * Copyright [2013-2015] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.binning;

import ml.shifu.shifu.core.binning.obj.NumBinInfo;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Reduces a list of pre-computed numeric bins to at most the expected number of bins.
 * <p>
 * The work is done in {@link #getDataBin()}: bins without instances are first folded into their
 * nearest non-empty neighbour, then adjacent bins are merged one pair at a time (always the pair
 * whose merge yields the smallest entropy change) until the expected bin count is reached.
 * <p>
 * Created by zhanhu on 7/6/16.
 */
public class DynamicBinning extends AbstractBinning<Double> {

    /** Added to the class counts so that {@code log2} is never evaluated at exactly zero. */
    private final static double EPS = 1e-6;

    /** Input bins, as handed to the constructor. They are cloned before any merging. */
    private List<NumBinInfo> binInfos;

    /**
     * @param binInfos       the initial bins to be reduced
     * @param expectedBinNum the bin count passed on to {@link AbstractBinning}
     */
    public DynamicBinning(List<NumBinInfo> binInfos, int expectedBinNum) {
        super(expectedBinNum);
        this.binInfos = binInfos;
    }

    /**
     * Intentionally a no-op: this binning works on the bins supplied at construction time
     * and does not consume raw values.
     *
     * @param val ignored
     */
    @Override
    public void addData(String val) {
        // Do nothing
    }

    /**
     * Builds the final bin boundaries.
     *
     * @return the left threshold of every remaining bin, in list order
     */
    @Override
    public List<Double> getDataBin() {
        List<NumBinInfo> workingBins = combineEmptyBin(binInfos);

        // Only shrink further when dropping the empty bins was not enough.
        if (workingBins.size() > super.expectedBinningNum) {
            double instanceTotal = getTotalInstCount(workingBins);
            workingBins = adjustBinInfos(workingBins, super.expectedBinningNum, instanceTotal);
        }

        List<Double> leftBoundaries = new ArrayList<Double>();
        for (NumBinInfo bin : workingBins) {
            leftBoundaries.add(bin.getLeftThreshold());
        }
        return leftBoundaries;
    }

    /**
     * Sums the instance counts of all given bins.
     *
     * @param mergedBinInfos bins to sum over
     * @return the total instance count as a double
     */
    private double getTotalInstCount(List<NumBinInfo> mergedBinInfos) {
        double sum = 0.0;
        for (NumBinInfo bin : mergedBinInfos) {
            sum += bin.getTotalInstCnt();
        }
        return sum;
    }

    /**
     * Repeatedly merges the best adjacent pair in place until the list holds no more than
     * {@code expectedBinningNum} bins, or until no merge candidate is found.
     *
     * @param mergedBinInfos     bins to shrink; modified in place
     * @param expectedBinningNum upper bound on the number of bins to keep
     * @param totalInstCnt       total instance count over all bins
     * @return the same list instance that was passed in
     */
    private List<NumBinInfo> adjustBinInfos(List<NumBinInfo> mergedBinInfos, int expectedBinningNum,
            double totalInstCnt) {
        while (mergedBinInfos.size() > expectedBinningNum) {
            int mergeIndex = getBestMergeNode(mergedBinInfos, totalInstCnt);
            if (mergeIndex <= 0) {
                // Index 0 has no left neighbour, i.e. nothing could be selected for merging.
                break;
            }

            // NOTE(review): mergeRight presumably folds the right bin's counts and right
            // threshold into the left bin - confirm against NumBinInfo.
            NumBinInfo leftNeighbour = mergedBinInfos.get(mergeIndex - 1);
            leftNeighbour.mergeRight(mergedBinInfos.get(mergeIndex));
            mergedBinInfos.remove(mergeIndex);
        }
        return mergedBinInfos;
    }

    /**
     * Finds the bin whose merge into its left neighbour changes the total entropy the least.
     *
     * @param mergedBinInfos candidate bins, in order
     * @param totalInstCnt   total instance count over all bins
     * @return index of the right-hand bin of the best pair, or 0 when no pair qualified
     *         (fewer than two bins, or no comparison succeeded)
     */
    private int getBestMergeNode(List<NumBinInfo> mergedBinInfos, double totalInstCnt) {
        double baseEntropy = calculateEntropy(mergedBinInfos, totalInstCnt);

        int bestIndex = 0;
        double smallestChange = Double.MAX_VALUE;

        Iterator<NumBinInfo> cursor = mergedBinInfos.iterator();
        if (!cursor.hasNext()) {
            return bestIndex;
        }

        NumBinInfo left = cursor.next();
        for (int rightIndex = 1; cursor.hasNext(); rightIndex++) {
            NumBinInfo right = cursor.next();

            // Simulate the merge on a copy so the real bins stay untouched.
            NumBinInfo combined = left.clone();
            combined.mergeRight(right);

            double entropyAfterMerge = baseEntropy - getInfoValue(left, totalInstCnt)
                    - getInfoValue(right, totalInstCnt) + getInfoValue(combined, totalInstCnt);
            double change = entropyAfterMerge - baseEntropy;

            // Strict comparison: on ties the earliest pair wins.
            if (change < smallestChange) {
                bestIndex = rightIndex;
                smallestChange = change;
            }
            left = right;
        }
        return bestIndex;
    }

    /**
     * Adds up the entropy contribution of every bin.
     *
     * @param mergedBinInfos bins to evaluate
     * @param totalInstCnt   total instance count over all bins
     * @return sum of {@link #getInfoValue(NumBinInfo, double)} over the list
     */
    private double calculateEntropy(List<NumBinInfo> mergedBinInfos, double totalInstCnt) {
        double sum = 0.0;
        for (NumBinInfo bin : mergedBinInfos) {
            sum += getInfoValue(bin, totalInstCnt);
        }
        return sum;
    }

    /**
     * Computes one bin's share of the entropy: the binary entropy of its positive/negative
     * split, weighted by the fraction of all instances that fall into the bin.
     * <p>
     * There is no guard for a bin whose total count is zero; the divisions below then
     * produce a non-finite result.
     *
     * @param numBinInfo   bin to evaluate
     * @param totalInstCnt total instance count over all bins
     * @return the weighted entropy contribution of the bin
     */
    private double getInfoValue(NumBinInfo numBinInfo, double totalInstCnt) {
        double weight = numBinInfo.getTotalInstCnt() / totalInstCnt;

        double smoothedPositives = numBinInfo.getPositiveInstCnt() + EPS;
        double smoothedNegatives = numBinInfo.getTotalInstCnt() - numBinInfo.getPositiveInstCnt() + EPS;

        double positiveRate = smoothedPositives / numBinInfo.getTotalInstCnt();
        double negativeRate = smoothedNegatives / numBinInfo.getTotalInstCnt();

        double signedEntropy = positiveRate * log2(positiveRate) + negativeRate * log2(negativeRate);
        return -1 * weight * signedEntropy;
    }

    /**
     * @param ratio value to take the logarithm of
     * @return the base-2 logarithm of {@code ratio}
     */
    private double log2(double ratio) {
        return Math.log(ratio) / Math.log(2.0d);
    }

    /**
     * Produces a new list in which every bin without instances has been absorbed by its
     * nearest non-empty bin. Surviving bins are clones whose thresholds are widened to cover
     * the directly adjacent empty bins assigned to them.
     *
     * @param binInfos the original bins; not modified
     * @return a new list holding the widened clones of the surviving bins
     */
    private List<NumBinInfo> combineEmptyBin(List<NumBinInfo> binInfos) {
        int binCount = binInfos.size();

        // owner[k] == k  -> bin k survives; otherwise bin k is absorbed by bin owner[k].
        int[] owner = new int[binCount];
        for (int k = 0; k < binCount; k++) {
            owner[k] = k;
            if (!isPopulated(binInfos.get(k))) {
                int nearest = findNearestNonEmptyBinInfo(binInfos, k);
                if (nearest >= 0) {
                    owner[k] = nearest;
                }
                // else: usually it won't happen here - the empty bin is kept as its own owner
            }
        }

        List<NumBinInfo> survivors = new LinkedList<NumBinInfo>();
        for (int k = 0; k < owner.length; k++) {
            if (owner[k] != k) {
                continue;
            }

            NumBinInfo widened = binInfos.get(k).clone();

            // Stretch to the left over the contiguous bins this one absorbed.
            for (int l = k - 1; l >= 0 && owner[l] == k; l--) {
                widened.setLeftThreshold(binInfos.get(l).getLeftThreshold());
            }
            // Stretch to the right in the same way.
            for (int r = k + 1; r < owner.length && owner[r] == k; r++) {
                widened.setRightThreshold(binInfos.get(r).getRightThreshold());
            }

            survivors.add(widened);
        }
        return survivors;
    }

    /**
     * Locates the closest bin that contains instances, looking in both directions from
     * {@code emptyIndex}. Distance is the gap between the facing thresholds of the two bins.
     * When both sides are equally close, the right-hand bin is chosen.
     *
     * @param binInfos   all bins, in order
     * @param emptyIndex index of the bin to find a neighbour for
     * @return index of the nearest non-empty bin, or -1 if there is none
     */
    private int findNearestNonEmptyBinInfo(List<NumBinInfo> binInfos, int emptyIndex) {
        int leftIndex = -1;
        double leftGap = Double.MAX_VALUE;
        for (int j = emptyIndex - 1; j >= 0; j--) {
            if (isPopulated(binInfos.get(j))) {
                leftIndex = j;
                leftGap = binInfos.get(emptyIndex).getLeftThreshold() - binInfos.get(j).getRightThreshold();
                break;
            }
        }

        int rightIndex = -1;
        double rightGap = Double.MAX_VALUE;
        for (int j = emptyIndex + 1; j < binInfos.size(); j++) {
            if (isPopulated(binInfos.get(j))) {
                rightIndex = j;
                rightGap = binInfos.get(j).getLeftThreshold() - binInfos.get(emptyIndex).getRightThreshold();
                break;
            }
        }

        if (leftGap < rightGap) {
            return leftIndex;
        }
        return rightIndex;
    }

    /**
     * @param bin bin to inspect
     * @return {@code true} when the bin's total instance count is greater than zero
     */
    private static boolean isPopulated(NumBinInfo bin) {
        return bin.getTotalInstCnt() > 0;
    }
}