package hex.singlenoderf;
import hex.singlenoderf.*;
import hex.singlenoderf.Data;
import water.util.Utils;
import java.util.Random;
/**The entropy formula is the classic Shannon entropy gain, which is:
*
* - \sum(p_i * log2(_pi))
*
* where p_i is the probability of i-th class occurring. The entropy is
* calculated for the left and right node after the given split and they are
* combined together weighted on their probability.
*
* ent left * weight left + ent right * weight right
* --------------------------------------------------
* total weight
*
* And to get the gain, this is subtracted from potential maximum of 1
* simulating the previous node. The biggest gain is selected as the tree split.
*
* The same is calculated also for exclusion, where left stands for the rows
* where column equals to the split point and right stands for all others.
*/
class EntropyStatistic extends Statistic {
public EntropyStatistic(Data data, int features, long seed, int exclusiveSplitLimit) { super(data, features, seed, exclusiveSplitLimit, false /*classification*/); }
/** LessThenEqual splits s*/
@Override protected Split ltSplit(int col, Data d, int[] dist, int distWeight, Random rand) {
final int[] distL = new int[d.classes()], distR = dist.clone();
final double upperBoundReduction = upperBoundReduction(d.classes());
double maxReduction = -1;
int bestSplit = -1;
int totL = 0, totR = 0; // Totals in the distribution
int classL = 0, classR = 0; // Count of non-zero classes in the left/right distributions
for (int e: distR) { // All zeros for the left, but need to compute for the right
totR += e;
if( e != 0 ) classR++;
}
// For this one column, look at all his split points and find the one with the best gain.
for (int i = 0; i < _columnDists[col].length - 1; ++i) {
int [] cdis = _columnDists[col][i];
for (int j = 0; j < distL.length; ++j) {
int v = cdis[j];
if( v == 0 ) continue; // No rows with this class
totL += v; totR -= v;
if( distL[j]== 0 ) classL++; // One-time transit from zero to non-zero for class j
distL[j] += v; distR[j] -= v;
if( distR[j]== 0 ) classR--; // One-time transit from non-zero to zero for class j
}
if (totL == 0) continue; // Totals are zero ==> this will not actually split anything
if (totR == 0) continue; // Totals are zero ==> this will not actually split anything
// Compute gain.
// If the distribution has only 1 class, the gain will be zero.
double eL = 0, eR = 0;
if( classL > 1 ) for (int e: distL) eL += gain(e,totL);
if( classR > 1 ) for (int e: distR) eR += gain(e,totR);
double eReduction = upperBoundReduction - ( (eL * totL + eR * totR) / (totL + totR) );
if (eReduction == maxReduction) {
// For now, don't break ties. Most ties are because we have several
// splits with NO GAIN. This happens *billions* of times in a standard
// covtype RF, because we have >100K leaves per tree (and 50 trees and
// 54 columns per leave and however many bins per column), and most
// leaves have no gain at most split points.
//if (rand.nextInt(10)<2) bestSplit=i;
} else if (eReduction > maxReduction) {
bestSplit = i; maxReduction = eReduction;
}
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist,_random))
: Split.split(col,bestSplit,maxReduction);
}
/**Gain function*/
private double gain(int e, int tot) {
if (e == 0) return 0;
double v = e/(double)tot;
double r = v * Math.log(v) / log2;
return -r;
}
/**Maximal entropy*/
private double upperBoundReduction(double classes) {
double p = 1/classes;
double r = p * Math.log(p)/log2 * classes;
return -r;
}
/**Compute an exclusive split (i.e. 'feature' '==' 'val') */
@Override protected Split eqSplit(int col, Data d, int[] dist, int distWeight, Random rand) {
final int[] distR = new int[d.classes()], distL = dist.clone();
final double upperBoundReduction = upperBoundReduction(d.classes());
double maxReduction = -1;
int bestSplit = -1;
int min = d.colMinIdx(col);
int max = d.colMaxIdx(col);
for (int i = min; i < max+1; ++i) {
for (int j = 0; j < distR.length; ++j) {
int v = _columnDists[col][i][j];
distL[j] += distR[j];
distR[j] = v;
distL[j] -= v;
}
int totL = 0, totR = 0;
for (int e: distL) totL += e;
if (totL == 0) continue;
for (int e: distR) totR += e;
if (totR == 0) continue;
double eL = 0, eR = 0;
for (int e: distL) eL += gain(e,totL);
for (int e: distR) eR += gain(e,totR);
double eReduction = upperBoundReduction - ( (eL * totL + eR * totR) / (totL + totR) );
if (eReduction == maxReduction){
if (rand.nextInt(10)<2) bestSplit=i; // randomly pick one out of several
} else if (eReduction > maxReduction) {
bestSplit = i; maxReduction = eReduction;
}
if (i==0 && d.columnArity(col) == 1) break; // for boolean features, only one split needs to be evaluated
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist,_random))
: Split.exclusion(col,bestSplit,maxReduction);
}
@Override
protected Split ltSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}
@Override
protected Split eqSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}
static final double log2 = Math.log(2);
}