/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.binning; import java.util.ArrayList; import java.util.List; import ml.shifu.shifu.core.binning.obj.LinkNode; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * EqualPopulationBinning class */ public class EqualPopulationBinning extends AbstractBinning<Double> { private final static Logger log = LoggerFactory.getLogger(EqualPopulationBinning.class); /** * The default scale for generating histogram is for keep accuracy. * General speaking, larger scale will guarantee better accuracy. But it will also cause worse efficiency * * TODO here to make it computable with expected bin num, 100 * 10 = 1000, if set bin num to 100, this should not be * 100 because of bad performance. */ public static final int HIST_SCALE = 100; /** * The maximum histogram unit count that could be hold */ private int maxHistogramUnitCnt; /** * Current histogram unit count in histogram */ private int currentHistogramUnitCnt; /** * The header and tail of histogram */ private LinkNode<HistogramUnit> header, tail; /** * Empty constructor : it is just for bin merging */ protected EqualPopulationBinning() { } /** * Construct @EqualPopulationBinning with expected bin number * * @param binningNum * the binningNum */ public EqualPopulationBinning(int binningNum) { this(binningNum, null); } /** * Construct @EqualPopulationBinning with expected bin number and with histogram scale factor * * @param binningNum * the binningNum * @param histogramScale * the histogram scale */ public EqualPopulationBinning(int binningNum, int histogramScale) { this(binningNum, null); this.maxHistogramUnitCnt = super.expectedBinningNum * histogramScale; } /** * Construct @@EqualPopulationBinning with expected bin number and * values list that would be treated as missing value * * @param binningNum * the binningNum * @param missingValList * the missingValList */ public EqualPopulationBinning(int binningNum, List<String> missingValList) { super(binningNum); this.maxHistogramUnitCnt = super.expectedBinningNum * HIST_SCALE; this.currentHistogramUnitCnt = 0; this.header = null; this.tail = null; } /** * Add the value (in format of text) into histogram with frequency 1. * First of all the input string will be trimmed and check whether it is missing value or not * If it is missing value, the missing value count will +1 * After that, the input string will be parsed into double. If it is not a double, invalid value count will +1 * * @see ml.shifu.shifu.core.binning.AbstractBinning#addData(java.lang.String) */ @Override public void addData(String val) { String fval = StringUtils.trimToEmpty(val); if(!isMissingVal(fval)) { double dval = 0; try { dval = Double.parseDouble(fval); } catch (NumberFormatException e) { // not a number? just ignore super.incInvalidValCnt(); return; } process(dval, 1); } else { super.incMissingValCnt(); } } /** * Add the value (in format of text) into histogram with weight. * First of all the input string will be trimmed and check whether it is missing value or not * If it is missing value, the missing value count will +1 * After that, the input string will be parsed into double. If it is not a double, invalid value count will +1 * * @param val * , string type value * @param wVal * , frequency or weight of this value */ public void addData(String val, double wVal) { String fval = StringUtils.trimToEmpty(val); if(!isMissingVal(fval)) { double dval = 0; try { dval = Double.parseDouble(fval); } catch (NumberFormatException e) { // not a number? just ignore super.incInvalidValCnt(); return; } process(dval, wVal); } else { super.incMissingValCnt(); } } /** * Add a value into histogram with frequency 1. * * @param val * the value to be added */ public void addData(double val) { process(val, 1); } /** * Add a value into histogram with frequency. * * @param val * the value to be added * @param frequency * the weight */ public void addData(double val, double frequency) { process(val, frequency); } /* * Generate data bin by expected bin number * * @see ml.shifu.shifu.core.binning.AbstractBinning#getDataBin() */ @Override public List<Double> getDataBin() { return getDataBin(super.expectedBinningNum); } /** * Get the median value in the histogram * * @return median value */ public Double getMedian() { List<Double> dataBinning = getDataBin(2); if(dataBinning.size() > 1) { return dataBinning.get(1); } else { return null; } } /** * Generate data bin by expected bin number * * @param toBinningNum * toBinningNum * @return list of data binning */ private List<Double> getDataBin(int toBinningNum) { List<Double> binBorders = new ArrayList<Double>(); binBorders.add(Double.NEGATIVE_INFINITY); if(this.currentHistogramUnitCnt <= toBinningNum) { // if the count of histogram unit is less than expected bin number // return each histogram unit as a bin. The boundary will be middle value // of every two histogram unit values convertHistogramUnitIntoBin(binBorders); return binBorders; } double totalCnt = getTotalInHistogram(); LinkNode<HistogramUnit> currStartPos = null; for(int j = 1; j < toBinningNum; j++) { double s = (j * totalCnt) / toBinningNum; LinkNode<HistogramUnit> pos = locateHistogram(s, currStartPos); if(pos == null || pos == currStartPos) { continue; } else { HistogramUnit chu = pos.data(); HistogramUnit nhu = pos.next().data(); double d = s - sum(chu.getHval()); if(d < 0) { double u = (chu.getHval() + nhu.getHval()) / 2; binBorders.add(u); currStartPos = pos; continue; } double a = nhu.getHcnt() - chu.getHcnt(); double b = 2 * chu.getHcnt(); double c = -2 * d; double z = 0.0; if(Double.compare(a, 0) == 0) { z = -1 * c / b; } else { z = (-1 * b + Math.sqrt(b * b - 4 * a * c)) / (2 * a); } double u = chu.getHval() + (nhu.getHval() - chu.getHval()) * z; binBorders.add(u); currStartPos = pos; } } return binBorders; } private void convertHistogramUnitIntoBin(List<Double> binBorders) { LinkNode<HistogramUnit> tmp = this.header; while(tmp != this.tail) { HistogramUnit chu = tmp.data(); HistogramUnit nhu = tmp.next().data(); binBorders.add((chu.getHval() + nhu.getHval()) / 2); tmp = tmp.next(); } } /** * Get the total value count in histogram * * @return total hosto value */ private double getTotalInHistogram() { double total = 0; LinkNode<HistogramUnit> tmp = this.header; while(tmp != null) { total += tmp.data().getHcnt(); tmp = tmp.next(); } return total; } /** * Locate histogram unit with just less than s, from some histogram unit * * @param s * the s value * @param startPos * start pos * @return next node */ private LinkNode<HistogramUnit> locateHistogram(double s, LinkNode<HistogramUnit> startPos) { while(startPos != this.tail) { if(startPos == null) { startPos = this.header; } HistogramUnit chu = startPos.data(); HistogramUnit nhu = startPos.next().data(); double sc = sum(chu.getHval()); double sn = sum(nhu.getHval()); if(sc >= s || (sc < s && s <= sn)) { return startPos; } startPos = startPos.next(); } return null; } /** * Sum the histogram's frequency whose value less than or equal some value * * @param hval * the h value * @return current sum */ private double sum(double hval) { LinkNode<HistogramUnit> posHistogramUnit = null; LinkNode<HistogramUnit> tmp = this.header; while(tmp != this.tail) { HistogramUnit chu = tmp.data(); HistogramUnit nhu = tmp.next().data(); if(chu.getHval() <= hval && hval < nhu.getHval()) { posHistogramUnit = tmp; break; } tmp = tmp.next(); } if(posHistogramUnit != null) { HistogramUnit chu = posHistogramUnit.data(); HistogramUnit nhu = posHistogramUnit.next().data(); double mb = chu.getHcnt() + (nhu.getHcnt() - nhu.getHcnt()) * (hval - chu.getHval()) / (nhu.getHval() - chu.getHval()); double s = (chu.getHcnt() + mb) * (hval - chu.getHval()) / (nhu.getHval() - chu.getHval()); s = s / 2; tmp = this.header; while(tmp != posHistogramUnit) { HistogramUnit hu = tmp.data(); s = s + hu.getHcnt(); tmp = tmp.next(); } return s + chu.getHcnt() / 2d; } else if(tmp == this.tail) { double sum = 0.0; tmp = this.header; while(tmp != null) { sum += tmp.data().getHcnt(); tmp = tmp.next(); } return sum; } return -1.0; } /** * Process the histogram with value and frequency * * @param dval * the d value * @param frequency * the weight */ private void process(double dval, double frequency) { LinkNode<HistogramUnit> node = new LinkNode<HistogramUnit>(new HistogramUnit(dval, frequency)); if(this.tail == null && this.maxHistogramUnitCnt > 1) { this.header = node; this.tail = node; this.currentHistogramUnitCnt = 1; } else { insertWithTrim(node); } } /** * Insert one @HistogramUnit node into the histogram. * Meanwhile it will try to keep the histogram as most @maxHistogramUnitCnt * So when inserting one node in, the method will try to find the place to insert as well as minimum interval * * @param node * current node */ private void insertWithTrim(LinkNode<HistogramUnit> node) { LinkNode<HistogramUnit> insertOpsUnit = null; LinkNode<HistogramUnit> minIntervalOpsUnit = null; Double minInterval = Double.MAX_VALUE; LinkNode<HistogramUnit> tmp = this.tail; while(tmp != null) { if(insertOpsUnit == null) { int res = Double.compare(tmp.data().getHval(), node.data().getHval()); if(res > 0) { // do nothing } else if(res == 0) { tmp.data().setHcnt(tmp.data().getHcnt() + node.data().getHcnt()); return; } else if(res < 0) { // find the right insert position to insert insertOpsUnit = tmp; double interval = node.data().getHval() - tmp.data().getHval(); if(interval < minInterval) { minInterval = interval; minIntervalOpsUnit = tmp; } if(tmp.next() != null) { interval = tmp.next().data().getHval() - node.data().getHval(); if(interval < minInterval) { minInterval = interval; minIntervalOpsUnit = node; } } } } if(tmp.next() != null) { LinkNode<HistogramUnit> next = tmp.next(); double interval = next.data().getHval() - tmp.data().getHval(); if(interval < minInterval) { minInterval = interval; minIntervalOpsUnit = tmp; } } tmp = tmp.prev(); } // insert node into linked list if(insertOpsUnit == null) { // insert as the first node if(this.header != null) { this.header.setPrev(node); } node.setNext(this.header); this.header = node; if(this.tail == null) { this.tail = node; } } else if(insertOpsUnit == this.tail) { // insert as the last node node.setPrev(insertOpsUnit); insertOpsUnit.setNext(node); this.tail = node; } else { // some intermediate node node.setNext(insertOpsUnit.next()); node.setPrev(insertOpsUnit); insertOpsUnit.next().setPrev(node); insertOpsUnit.setNext(node); } // merge info into next node if(this.currentHistogramUnitCnt == this.maxHistogramUnitCnt) { LinkNode<HistogramUnit> nextNode = minIntervalOpsUnit.next(); HistogramUnit chu = minIntervalOpsUnit.data(); HistogramUnit nhu = nextNode.data(); nhu.setHval((chu.getHval() * chu.getHcnt() + nhu.getHval() * nhu.getHcnt()) / (chu.getHcnt() + nhu.getHcnt())); nhu.setHcnt(chu.getHcnt() + nhu.getHcnt()); removeCurrentNode(minIntervalOpsUnit, nextNode); } else { this.currentHistogramUnitCnt++; } } private void removeCurrentNode(LinkNode<HistogramUnit> currNode, LinkNode<HistogramUnit> nextNode) { // remove current node if(currNode == this.header) { nextNode.setPrev(null); this.header = nextNode; } else { LinkNode<HistogramUnit> prev = currNode.prev(); prev.setNext(nextNode); nextNode.setPrev(prev); } } @Override public void mergeBin(AbstractBinning<?> another) { EqualPopulationBinning binning = (EqualPopulationBinning) another; super.mergeBin(binning); LinkNode<HistogramUnit> tmp = binning.header; while(tmp != null) { this.insertWithTrim(new LinkNode<HistogramUnit>(tmp.data())); tmp = tmp.next(); } } protected void stringToObj(String objValStr) { super.stringToObj(objValStr); String[] objStrArr = objValStr.split(Character.toString(FIELD_SEPARATOR), -1); maxHistogramUnitCnt = Integer.parseInt(objStrArr[4]); if(objStrArr.length > 5 && StringUtils.isNotBlank(objStrArr[5])) { String[] histogramStrArr = objStrArr[5].split(Character.toString(SETLIST_SEPARATOR), -1); for(String histogramStr: histogramStrArr) { HistogramUnit hu = HistogramUnit.stringToObj(histogramStr); this.insertWithTrim(new LinkNode<HistogramUnit>(hu)); } } else { log.warn("Empty categorical bin - " + objValStr); } } public String objToString() { List<String> histogramStrList = new ArrayList<String>(); if(this.header != null) { LinkNode<HistogramUnit> tmp = this.header; while(tmp != null) { histogramStrList.add(tmp.data().objToString()); tmp = tmp.next(); } } return super.objToString() + Character.toString(FIELD_SEPARATOR) + Integer.toString(maxHistogramUnitCnt) + Character.toString(FIELD_SEPARATOR) + StringUtils.join(histogramStrList, SETLIST_SEPARATOR); } /** * HistogramUnit class is the unit for histogram */ public static class HistogramUnit implements Comparable<HistogramUnit> { private double hval; private double hcnt; public HistogramUnit(double hval, double hcnt) { this.hval = hval; this.hcnt = hcnt; } public double getHval() { return hval; } public void setHval(double hval) { this.hval = hval; } public double getHcnt() { return hcnt; } public void setHcnt(double hcnt) { this.hcnt = hcnt; } /* * (non-Javadoc) * * @see java.lang.Comparable#compareTo(java.lang.Object) */ @Override public int compareTo(HistogramUnit another) { return Double.compare(hval, another.getHval()); } /* * (non-Javadoc) * * @see java.lang.Object#hashCode() */ @Override public int hashCode() { final int prime = 31; int result = 1; long temp; temp = Double.doubleToLongBits(hval); result = prime * result + (int) (temp ^ (temp >>> 32)); return result; } /* * (non-Javadoc) * * @see java.lang.Object#equals(java.lang.Object) */ @Override public boolean equals(Object obj) { if(this == obj) return true; if(obj == null) return false; if(!(obj instanceof HistogramUnit)) return false; HistogramUnit other = (HistogramUnit) obj; return Double.compare(hval, other.hval) == 0; } @Override public String toString() { return "[" + hval + ", " + hcnt + "]"; } public String objToString() { return Double.toString(hval) + Character.toString(PAIR_SEPARATOR) + Double.toString(hcnt); } public static HistogramUnit stringToObj(String histogramStr) { String[] fields = StringUtils.split(histogramStr, PAIR_SEPARATOR); return new HistogramUnit(Double.parseDouble(fields[0]), Double.parseDouble(fields[1])); } } }