/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dt; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Set; import ml.shifu.shifu.container.obj.ColumnConfig; /** * Different {@link #Impurity()} strategies to compute impurity and gain for each tree node. * * <p> * {@link Entropy} and {@link Gini} are mostly for classification while {@link Variance} is for regression. * * <p> * TODO For categorical feature, do a shuffle in {@link #computeImpurity(double[], ColumnConfig)} firstly, sort by * centroid then. * * @author Zhang David (pengzhang@paypal.com) */ public abstract class Impurity { /** * # of values collected, for example in {@link Variance}, count, sum and squaredSum are collected, statsSize is 3. * For Gini and Entropy, each class, count are selected, for binary classification, statsSize is 2. */ protected int statsSize; /** * Per node, min instances, if less than this value, such gain info will be ignored. */ protected int minInstancesPerNode = 1; /** * Min info gain, if less than this value, such gain info will be ignored. */ protected double minInfoGain = 0d; /** * Compute impurity by feature statistics. Stats array are for all bins. * * @param stats * the stats array * @param confg * column config instance * @return gain info based on stats */ public abstract GainInfo computeImpurity(double[] stats, ColumnConfig confg); /** * Update bin stats value per feature. * * @param featuerStatistic * the stats array * @param binIndex * the bin index * @param label * the label * @param significance * the significance * @param weight * the weight */ public abstract void featureUpdate(double[] featuerStatistic, int binIndex, float label, float significance, float weight); /** * @return the statsSize */ public int getStatsSize() { return statsSize; } /** * @param statsSize * the statsSize to set */ public void setStatsSize(int statsSize) { this.statsSize = statsSize; } } /** * Variance impurity value is computed by ((sumSquare - (sum * sum) / count) / count). * * @author Zhang David (pengzhang@paypal.com) */ class Variance extends Impurity { public Variance() { // 3 are count, sum and sumSquare super.statsSize = 3; } public Variance(int minInstancesPerNode, double minInfoGain) { super.statsSize = 3; super.minInstancesPerNode = minInstancesPerNode; super.minInfoGain = minInfoGain; } @Override public GainInfo computeImpurity(double[] stats, ColumnConfig config) { double count = 0d, sum = 0d, sumSquare = 0d; int binSize = stats.length / super.statsSize; for(int i = 0; i < binSize; i++) { count += stats[i * super.statsSize]; sum += stats[i * super.statsSize + 1]; sumSquare += stats[i * super.statsSize + 2]; } double impurity = getImpurity(count, sum, sumSquare); Predict predict = new Predict(count == 0d ? 0d : sum / count); double leftCount = 0d, leftSum = 0d, leftSumSquare = 0d; double rightCount = 0d, rightSum = 0d, rightSumSquare = 0d; List<GainInfo> internalGainList = new ArrayList<GainInfo>(); Set<Short> leftCategories = config.isCategorical() ? new SimpleBitSet<Short>(config.getBinCategory().size() + 1) : null; List<Pair> categoricalOrderList = null; if(config.isCategorical()) { // sort by predict and then pick the best split categoricalOrderList = getCategoricalOrderList(stats, binSize); } int leftCategorySetSize = 0; for(int i = 0; i < (binSize - 1); i++) { int index = i; if(config.isCategorical()) { index = categoricalOrderList.get(i).index; } leftCount += stats[index * super.statsSize]; leftSum += stats[index * super.statsSize + 1]; leftSumSquare += stats[index * super.statsSize + 2]; rightCount = count - leftCount; rightSum = sum - leftSum; rightSumSquare = sumSquare - leftSumSquare; if(leftCount <= minInstancesPerNode || rightCount <= minInstancesPerNode) { continue; } double leftWeight = leftCount / count; double rightWeight = rightCount / count; double leftImpurity = getImpurity(leftCount, leftSum, leftSumSquare); double rightImpurity = getImpurity(rightCount, rightSum, rightSumSquare); double gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity; if(gain <= minInfoGain) { continue; } Split split = null; if(config.isCategorical()) { // cast to short is safe as we limit max bin size to Short.MAX_VALUE while may be not good for scale if(index >= config.getBinCategory().size()) { // missing value bin, all missing value will be replaced by empty string in norm step leftCategories.add((short) (config.getBinCategory().size())); } else { leftCategories.add((short) index); } leftCategorySetSize += 1; boolean isLeft = true; Set<Short> rightCategories = null; if(config.getBinCategory().size() + 1 <= leftCategorySetSize * 2) { // too many in left, use right; isLeft = false; rightCategories = new SimpleBitSet<Short>(config.getBinCategory().size() + 1); for(short j = 0; j < (config.getBinCategory().size() + 1); j++) { if(!leftCategories.contains(j)) { rightCategories.add(j); } } } // new hash set to copy a new one avoid share object issue split = new Split(config.getColumnNum(), FeatureType.CATEGORICAL, 0d, isLeft, new SimpleBitSet<Short>( config.getBinCategory().size() + 1, (SimpleBitSet<Short>) (isLeft ? leftCategories : rightCategories))); } else { split = new Split(config.getColumnNum(), FeatureType.CONTINUOUS, config.getBinBoundary().get(index + 1), false, null); } Predict leftPredict = new Predict(leftCount == 0d ? 0d : leftSum / leftCount); Predict rightPredict = new Predict(rightCount == 0d ? 0d : rightSum / rightCount); internalGainList.add(new GainInfo(gain, impurity, predict, leftImpurity, rightImpurity, leftPredict, rightPredict, split, count)); } return GainInfo.getGainInfoByMaxGain(internalGainList); } protected List<Pair> getCategoricalOrderList(double[] stats, int binSize) { List<Pair> categoricalOrderList = new ArrayList<Pair>(binSize); for(int i = 0; i < binSize; i++) { // set default = double min to make it sorted at first double binPredict = Double.MIN_VALUE; if(stats[i * super.statsSize] != 0d) { binPredict = stats[i * super.statsSize + 1] / stats[i * super.statsSize]; } // for variance, use predict value to sort categoricalOrderList.add(new Pair(i, binPredict)); } Collections.sort(categoricalOrderList, new Comparator<Pair>() { @Override public int compare(Pair o1, Pair o2) { return Double.valueOf(o1.value).compareTo(Double.valueOf(o2.value)); } }); return categoricalOrderList; } protected double getImpurity(double count, double sum, double sumSquare) { return (count != 0d) ? ((sumSquare - (sum * sum) / count) / count) : 0d; } @Override public void featureUpdate(double[] featuerStatistic, int binIndex, float label, float significance, float weight) { featuerStatistic[binIndex * super.statsSize] += (significance * weight); featuerStatistic[binIndex * super.statsSize + 1] += (label * significance * weight); featuerStatistic[binIndex * super.statsSize + 2] += (label * label * significance * weight); } } /** * Reference from: * * https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_criterion.pyx#L1264 * J. Friedman, Greedy Function Approximation: A Gradient Boosting Machine, The Annals of Statistics, Vol. 29, No. 5, * 2001. * * @author Zhang David (pengzhang@paypal.com) */ class FriedmanMSE extends Variance { public FriedmanMSE() { // 3 are count, sum and sumSquare super.statsSize = 3; } public FriedmanMSE(int minInstancesPerNode, double minInfoGain) { super.statsSize = 3; super.minInstancesPerNode = minInstancesPerNode; super.minInfoGain = minInfoGain; } @Override public GainInfo computeImpurity(double[] stats, ColumnConfig config) { double count = 0d, sum = 0d, sumSquare = 0d; int binSize = stats.length / super.statsSize; for(int i = 0; i < binSize; i++) { count += stats[i * super.statsSize]; sum += stats[i * super.statsSize + 1]; sumSquare += stats[i * super.statsSize + 2]; } double impurity = getImpurity(count, sum, sumSquare); Predict predict = new Predict(count == 0d ? 0d : sum / count); double leftCount = 0d, leftSum = 0d, leftSumSquare = 0d; double rightCount = 0d, rightSum = 0d, rightSumSquare = 0d; List<GainInfo> internalGainList = new ArrayList<GainInfo>(); Set<Short> leftCategories = config.isCategorical() ? new SimpleBitSet<Short>(config.getBinCategory().size() + 1) : null; List<Pair> categoricalOrderList = null; if(config.isCategorical()) { // sort by predict and then pick the best split categoricalOrderList = getCategoricalOrderList(stats, binSize); } int leftCategorySetSize = 0; for(int i = 0; i < (binSize - 1); i++) { int index = i; if(config.isCategorical()) { index = categoricalOrderList.get(i).index; } leftCount += stats[index * super.statsSize]; leftSum += stats[index * super.statsSize + 1]; leftSumSquare += stats[index * super.statsSize + 2]; rightCount = count - leftCount; rightSum = sum - leftSum; rightSumSquare = sumSquare - leftSumSquare; if(leftCount <= minInstancesPerNode || rightCount <= minInstancesPerNode) { continue; } double leftImpurity = getImpurity(leftCount, leftSum, leftSumSquare); double rightImpurity = getImpurity(rightCount, rightSum, rightSumSquare); double diff = rightCount * leftSum - leftCount * rightSum; double gain = (diff * diff) / (leftCount * rightCount * (leftCount + rightCount)); if(gain <= minInfoGain) { continue; } Split split = null; if(config.isCategorical()) { // cast to short is safe as we limit max bin size to Short.MAX_VALUE while may be not good for scale if(index >= config.getBinCategory().size()) { // missing value bin, all missing value will be replaced by empty string in norm step leftCategories.add((short) (config.getBinCategory().size())); } else { leftCategories.add((short) index); } leftCategorySetSize += 1; boolean isLeft = true; Set<Short> rightCategories = null; if(config.getBinCategory().size() + 1 <= leftCategorySetSize * 2) { // too many in left, use right; isLeft = false; rightCategories = new SimpleBitSet<Short>(config.getBinCategory().size() + 1); for(short j = 0; j < (config.getBinCategory().size() + 1); j++) { if(!leftCategories.contains(j)) { rightCategories.add(j); } } } // new hash set to copy a new one avoid share object issue split = new Split(config.getColumnNum(), FeatureType.CATEGORICAL, 0d, isLeft, new SimpleBitSet<Short>( config.getBinCategory().size() + 1, (SimpleBitSet<Short>) (isLeft ? leftCategories : rightCategories))); } else { split = new Split(config.getColumnNum(), FeatureType.CONTINUOUS, config.getBinBoundary().get(index + 1), false, null); } Predict leftPredict = new Predict(leftCount == 0d ? 0d : leftSum / leftCount); Predict rightPredict = new Predict(rightCount == 0d ? 0d : rightSum / rightCount); internalGainList.add(new GainInfo(gain, impurity, predict, leftImpurity, rightImpurity, leftPredict, rightPredict, split, count)); } return GainInfo.getGainInfoByMaxGain(internalGainList); } } /** * Entropy impurity value for classification tree. Entropy formula is SUM(- rate * log(rate) ) * * @author Zhang David (pengzhang@paypal.com) */ class Entropy extends Impurity { public Entropy(int numClasses, int minInstancesPerNode, double minInfoGain) { assert numClasses > 0; super.statsSize = numClasses; super.minInstancesPerNode = minInstancesPerNode; super.minInfoGain = minInfoGain; } @Override public GainInfo computeImpurity(double[] stats, ColumnConfig config) { int numClasses = super.statsSize; double[] statsByClasses = new double[numClasses]; for(int i = 0; i < stats.length / numClasses; i++) { for(int j = 0; j < numClasses; j++) { double oneStatValue = stats[i * super.statsSize + j]; statsByClasses[j] += oneStatValue; } } List<Pair> categoricalOrderList = null; if(config.isCategorical()) { // sort by predict and then pick the best split categoricalOrderList = getCategoricalOrderList(stats, stats.length / super.statsSize); } InternalEntropyInfo info = getEntropyInterInfo(statsByClasses); // prob only effective in binary classes Predict predict = new Predict(info.sumAll == 0d ? 0d : (statsByClasses[1] / info.sumAll), (byte) info.indexOfLargestElement); double[] leftStatByClasses = new double[numClasses]; double[] rightStatByClasses = new double[numClasses]; List<GainInfo> internalGainList = new ArrayList<GainInfo>(); Set<Short> leftCategories = config.isCategorical() ? new SimpleBitSet<Short>(config.getBinCategory().size() + 1) : null; int leftCategorySetSize = 0; for(int i = 0; i < (stats.length / numClasses - 1); i++) { int index = i; if(config.isCategorical()) { index = categoricalOrderList.get(i).index; } for(int j = 0; j < leftStatByClasses.length; j++) { leftStatByClasses[j] += stats[index * numClasses + j]; } InternalEntropyInfo leftInfo = getEntropyInterInfo(leftStatByClasses); Predict leftPredict = new Predict(leftInfo.sumAll == 0d ? 0d : (leftStatByClasses[1] / leftInfo.sumAll), (byte) leftInfo.indexOfLargestElement); for(int j = 0; j < leftStatByClasses.length; j++) { rightStatByClasses[j] = statsByClasses[j] - leftStatByClasses[j]; } InternalEntropyInfo rightInfo = getEntropyInterInfo(rightStatByClasses); if(leftInfo.sumAll <= minInstancesPerNode || rightInfo.sumAll <= minInstancesPerNode) { continue; } Predict rightPredict = new Predict( rightInfo.sumAll == 0d ? 0d : (rightStatByClasses[1] / rightInfo.sumAll), (byte) rightInfo.indexOfLargestElement); double leftWeight = info.sumAll == 0d ? 0d : (leftInfo.sumAll / info.sumAll); double rightWeight = info.sumAll == 0d ? 0d : (rightInfo.sumAll / info.sumAll); double gain = info.impurity - leftWeight * leftInfo.impurity - rightWeight * rightInfo.impurity; if(gain <= minInfoGain) { continue; } Split split = null; if(config.isCategorical()) { if(index >= config.getBinCategory().size()) { // missing value bin, all missing value will be replaced by empty string in norm step leftCategories.add((short) (config.getBinCategory().size())); } else { leftCategories.add((short) index); } leftCategorySetSize += 1; boolean isLeft = true; Set<Short> rightCategories = null; if(config.getBinCategory().size() + 1 <= leftCategorySetSize * 2) { // too many in left, use right; isLeft = false; rightCategories = new SimpleBitSet<Short>(config.getBinCategory().size() + 1); for(short j = 0; j < (config.getBinCategory().size() + 1); j++) { if(!leftCategories.contains(j)) { rightCategories.add(j); } } } // new hash set to copy a new one avoid share object issue split = new Split(config.getColumnNum(), FeatureType.CATEGORICAL, 0d, isLeft, new SimpleBitSet<Short>( config.getBinCategory().size() + 1, (SimpleBitSet<Short>) (isLeft ? leftCategories : rightCategories))); } else { split = new Split(config.getColumnNum(), FeatureType.CONTINUOUS, config.getBinBoundary().get(index + 1), false, null); } internalGainList.add(new GainInfo(gain, info.impurity, predict, leftInfo.impurity, rightInfo.impurity, leftPredict, rightPredict, split, info.sumAll)); } return GainInfo.getGainInfoByMaxGain(internalGainList); } private List<Pair> getCategoricalOrderList(double[] stats, int binSize) { List<Pair> categoricalOrderList = new ArrayList<Pair>(binSize); for(int i = 0; i < binSize; i++) { // for entropy, use bin positive rate to sort double sum = stats[i * super.statsSize] + stats[i * super.statsSize + 1]; double binPredict = 0d; if(sum != 0d) { binPredict = stats[i * super.statsSize + 1] / sum; } categoricalOrderList.add(new Pair(i, binPredict)); } Collections.sort(categoricalOrderList, new Comparator<Pair>() { @Override public int compare(Pair o1, Pair o2) { return Double.valueOf(o1.value).compareTo(Double.valueOf(o2.value)); } }); return categoricalOrderList; } private InternalEntropyInfo getEntropyInterInfo(double[] statsByClasses) { double sumAll = 0; for(int i = 0; i < statsByClasses.length; i++) { sumAll += statsByClasses[i]; } double impurity = 0d; int indexOfLargestElement = -1; if(sumAll != 0d) { double maxElement = Double.MIN_VALUE; for(int i = 0; i < statsByClasses.length; i++) { double rate = statsByClasses[i] / sumAll; if(rate != 0d) { impurity -= rate * log2(rate); if(statsByClasses[i] > maxElement) { maxElement = statsByClasses[i]; indexOfLargestElement = i; } } } } return new InternalEntropyInfo(sumAll, indexOfLargestElement, impurity); } private static class InternalEntropyInfo { double sumAll; double indexOfLargestElement; double impurity; public InternalEntropyInfo(double sumAll, double indexOfLargestElement, double impurity) { this.sumAll = sumAll; this.indexOfLargestElement = indexOfLargestElement; this.impurity = impurity; } } @Override public void featureUpdate(double[] featuerStatistic, int binIndex, float label, float significance, float weight) { // label + 0.1f to avoid 0.99999f is converted to 0 featuerStatistic[binIndex * super.statsSize + (int) (label + 0.000001f)] += (significance * weight); } private double log2(double x) { return Math.log(x) / Math.log(2); } } /** * Gini impurity value for classification tree. Entropy formula is SUM(- rate * rate ) * * @author Zhang David (pengzhang@paypal.com) */ class Gini extends Impurity { public Gini(int numClasses, int minInstancesPerNode, double minInfoGain) { assert numClasses > 0; super.statsSize = numClasses; super.minInstancesPerNode = minInstancesPerNode; super.minInfoGain = minInfoGain; } @Override public GainInfo computeImpurity(double[] stats, ColumnConfig config) { int numClasses = super.statsSize; double[] statsByClasses = new double[numClasses]; for(int i = 0; i < stats.length / numClasses; i++) { for(int j = 0; j < numClasses; j++) { double oneStatValue = stats[i * super.statsSize + j]; statsByClasses[j] += oneStatValue; } } List<Pair> categoricalOrderList = null; if(config.isCategorical()) { // sort by predict and then pick the best split categoricalOrderList = getCategoricalOrderList(stats, stats.length / super.statsSize); } InternalGiniInfo info = getGiniInfo(statsByClasses); // prob only effective in binary classes Predict predict = new Predict(info.sumAll == 0d ? 0d : statsByClasses[1] / info.sumAll, (byte) info.indexOfLargestElement); double[] leftStatByClasses = new double[numClasses]; double[] rightStatByClasses = new double[numClasses]; List<GainInfo> internalGainList = new ArrayList<GainInfo>(); Set<Short> leftCategories = config.isCategorical() ? new SimpleBitSet<Short>(config.getBinCategory().size() + 1) : null; int leftCategorySetSize = 0; for(int i = 0; i < (stats.length / numClasses - 1); i++) { int index = i; if(config.isCategorical()) { index = categoricalOrderList.get(i).index; } for(int j = 0; j < leftStatByClasses.length; j++) { leftStatByClasses[j] += stats[index * numClasses + j]; } InternalGiniInfo leftInfo = getGiniInfo(leftStatByClasses); Predict leftPredict = new Predict(leftInfo.sumAll == 0d ? 0d : leftStatByClasses[1] / leftInfo.sumAll, (byte) leftInfo.indexOfLargestElement); for(int j = 0; j < leftStatByClasses.length; j++) { rightStatByClasses[j] = statsByClasses[j] - leftStatByClasses[j]; } InternalGiniInfo rightInfo = getGiniInfo(rightStatByClasses); if(leftInfo.sumAll <= minInstancesPerNode || rightInfo.sumAll <= minInstancesPerNode) { continue; } Predict rightPredict = new Predict(rightInfo.sumAll == 0d ? 0d : rightStatByClasses[1] / rightInfo.sumAll, (byte) rightInfo.indexOfLargestElement); double leftWeight = info.sumAll == 0d ? 0d : (leftInfo.sumAll / info.sumAll); double rightWeight = info.sumAll == 0d ? 0d : (rightInfo.sumAll / info.sumAll); double gain = info.impurity - leftWeight * leftInfo.impurity - rightWeight * rightInfo.impurity; if(gain <= minInfoGain) { continue; } Split split = null; if(config.isCategorical()) { // cast to short is safe as we limit max bin size to Short.MAX_VALUE while may be not good for scale if(index >= config.getBinCategory().size()) { // missing value bin, all missing value will be replaced by empty string in norm step leftCategories.add((short) (config.getBinCategory().size())); } else { leftCategories.add((short) index); } leftCategorySetSize += 1; boolean isLeft = true; Set<Short> rightCategories = null; if(config.getBinCategory().size() + 1 <= leftCategorySetSize * 2) { // too many in left, use right; isLeft = false; rightCategories = new SimpleBitSet<Short>(config.getBinCategory().size() + 1); for(short j = 0; j < (config.getBinCategory().size() + 1); j++) { if(!leftCategories.contains(j)) { rightCategories.add(j); } } } // new hash set to copy a new one avoid share object issue split = new Split(config.getColumnNum(), FeatureType.CATEGORICAL, 0d, isLeft, new SimpleBitSet<Short>( config.getBinCategory().size() + 1, (SimpleBitSet<Short>) (isLeft ? leftCategories : rightCategories))); } else { split = new Split(config.getColumnNum(), FeatureType.CONTINUOUS, config.getBinBoundary().get(index + 1), false, null); } internalGainList.add(new GainInfo(gain, info.impurity, predict, leftInfo.impurity, rightInfo.impurity, leftPredict, rightPredict, split, info.sumAll)); } return GainInfo.getGainInfoByMaxGain(internalGainList); } private List<Pair> getCategoricalOrderList(double[] stats, int binSize) { List<Pair> categoricalOrderList = new ArrayList<Pair>(binSize); for(int i = 0; i < binSize; i++) { // for gini, use bin positive rate to sort double sum = stats[i * super.statsSize] + stats[i * super.statsSize + 1]; double binPredict = 0d; if(sum != 0d) { binPredict = stats[i * super.statsSize + 1] / sum; } categoricalOrderList.add(new Pair(i, binPredict)); } Collections.sort(categoricalOrderList, new Comparator<Pair>() { @Override public int compare(Pair o1, Pair o2) { return Double.valueOf(o1.value).compareTo(Double.valueOf(o2.value)); } }); return categoricalOrderList; } private InternalGiniInfo getGiniInfo(double[] statsByClasses) { double sumAll = 0; for(int i = 0; i < statsByClasses.length; i++) { sumAll += statsByClasses[i]; } double impurity = 0d; int indexOfLargestElement = -1; if(sumAll != 0d) { double maxElement = Double.MIN_VALUE; for(int i = 0; i < statsByClasses.length; i++) { double rate = statsByClasses[i] / sumAll; impurity -= rate * rate; if(statsByClasses[i] > maxElement) { maxElement = statsByClasses[i]; indexOfLargestElement = i; } } } return new InternalGiniInfo(sumAll, indexOfLargestElement, impurity); } private static class InternalGiniInfo { double sumAll; double indexOfLargestElement; double impurity; public InternalGiniInfo(double sumAll, double indexOfLargestElement, double impurity) { this.sumAll = sumAll; this.indexOfLargestElement = indexOfLargestElement; this.impurity = impurity; } } @Override public void featureUpdate(double[] featuerStatistic, int binIndex, float label, float significance, float weight) { // label + 0.1f to avoid 0.99999f is converted to 0 featuerStatistic[binIndex * super.statsSize + (int) (label + 0.000001f)] += (significance * weight); } } class Pair { public Pair(int index, double value) { this.index = index; this.value = value; } int index; double value; }