package hex.singlenoderf;
import water.util.Utils;
import java.util.Random;
/** Computes the gini split statistics.
*
* The Gini fitness is calculated as a probability that the element will be
* misclassified, which is:
*
* 1 - \sum(p_i^2)
*
* This is computed for the left and right subtrees and added together:
*
gini left * weight left + gini right * weight right
* --------------------------------------------------
* weight total
*
* And subtracted from an ideal worst 1 to simulate the gain from previous node.
* The best gain is then selected. Same is done for exclusions, where again
* left stands for the rows with column value equal to the split value and
* right for all different ones.
*/
public class GiniStatistic extends Statistic {
public GiniStatistic(Data data, int features, long seed, int exclusiveSplitLimit) { super(data, features, seed, exclusiveSplitLimit, false /*classification*/); }
private double gini(int[] dd, int sum) {
double result = 1.0;
double sd = (double)sum;
for (int d : dd) {
double tmp = ((double)d)/sd;
result -= tmp*tmp;
}
return result;
}
@Override protected Split ltSplit(int col, Data d, int[] dist, int distWeight, Random _) {
int[] leftDist = new int[d.classes()];
int[] riteDist = dist.clone();
int lW = 0;
int rW = distWeight;
double totWeight = rW;
// we are not a single class, calculate the best split for the column
int bestSplit = -1;
double bestFitness = 0.0;
assert leftDist.length==_columnDists[col][0].length;
for (int i = 0; i < _columnDists[col].length-1; ++i) {
// first copy the i-th guys from rite to left
for (int j = 0; j < leftDist.length; ++j) {
int t = _columnDists[col][i][j];
lW += t;
rW -= t;
leftDist[j] += t;
riteDist[j] -= t;
}
// now make sure we have something to split
if( lW == 0 || rW == 0 ) continue;
double f = 1.0 -
(gini(leftDist,lW) * ((double)lW / totWeight) +
gini(riteDist,rW) * ((double)rW / totWeight));
if( f>bestFitness ) { // Take split with largest fitness
bestSplit = i;
bestFitness = f;
}
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist, _random))
: Split.split(col, bestSplit, bestFitness);
}
@Override protected Split eqSplit(int colIndex, Data d, int[] dist, int distWeight, Random _) {
int[] inclDist = new int[d.classes()];
int[] exclDist = dist.clone();
// we are not a single class, calculate the best split for the column
int bestSplit = -1;
double bestFitness = 0.0; // Fitness to maximize
for( int i = 0; i < _columnDists[colIndex].length-1; ++i ) {
// first copy the i-th guys from rite to left
int sumt = 0;
for( int j = 0; j < inclDist.length; ++j ) {
int t = _columnDists[colIndex][i][j];
sumt += t;
inclDist[j] = t;
exclDist[j] = dist[j] - t;
}
int inclW = sumt;
int exclW = distWeight - inclW;
// now make sure we have something to split
if( inclW == 0 || exclW == 0 ) continue;
double f = 1.0 -
(gini(inclDist,inclW) * ((double)inclW / distWeight) +
gini(exclDist,exclW) * ((double)exclW / distWeight));
if( f>bestFitness ) { // Take split with largest fitness
bestSplit = i;
bestFitness = f;
}
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist, _random))
: Split.exclusion(colIndex, bestSplit, bestFitness);
}
@Override
protected Split ltSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}
@Override
protected Split eqSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}
}