package hex.singlenoderf;
import water.util.Utils;
import java.util.Random;
/** Computes the twoing split statistic.
*
* The decrease in (twoing) impurity as the result of a given split is
* computed as follows:
*
* 1 weight left weight right
* - * ------------ * ------------- * twoing( left, right )
* 4 weight total weight total
*
* twoing( left, right ) = ( \sum_i |p_i(left) - p_i(right)| )^2, where
* p_i( left ) is the fraction of observations in the left node that belong to class i, and
* p_i( right ) is the fraction of observations in the right node that belong to class i.
*
* The split that produces the largest decrease in impurity is selected.
* The same statistic is used for exclusion (equality) splits, where "left" stands for
* the rows whose column value equals the split value and "right" for all other rows.
*
* ece 11/14
*/
public class TwoingStatistic extends Statistic {
public TwoingStatistic(Data data, int features, long seed, int exclusiveSplitLimit) { super(data, features, seed, exclusiveSplitLimit, false /*classification*/); }
private double twoing(int[] dd_l, int sum_l, int[] dd_r, int sum_r ) {
double result = 0.0;
double sd_l = (double)sum_l;
double sd_r = (double)sum_r;
for (int i = 0; i < dd_l.length; i++) {
double tmp = Math.abs(((double)dd_l[i])/sd_l - ((double)dd_r[i])/sd_r);
result = result + tmp;
}
result = result * result;
return result;
}
@Override protected Split ltSplit(int col, Data d, int[] dist, int distWeight, Random _) {
int[] leftDist = new int[d.classes()];
int[] riteDist = dist.clone();
int lW = 0;
int rW = distWeight;
double totWeight = rW;
// we are not a single class, calculate the best split for the column
int bestSplit = -1;
double bestFitness = 0.0;
assert leftDist.length==_columnDists[col][0].length;
for (int i = 0; i < _columnDists[col].length-1; ++i) {
// first copy the i-th guys from rite to left
for (int j = 0; j < leftDist.length; ++j) {
int t = _columnDists[col][i][j];
lW += t;
rW -= t;
leftDist[j] += t;
riteDist[j] -= t;
}
// now make sure we have something to split
if( lW == 0 || rW == 0 ) continue;
double f = 0.25 * ((double)lW / totWeight) * ((double)rW / totWeight) *
twoing(leftDist, lW, riteDist, rW);
if( f>bestFitness ) { // Take split with largest fitness
bestSplit = i;
bestFitness = f;
}
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist, _random))
: Split.split(col, bestSplit, bestFitness);
}
@Override protected Split eqSplit(int colIndex, Data d, int[] dist, int distWeight, Random _) {
int[] inclDist = new int[d.classes()];
int[] exclDist = dist.clone();
// we are not a single class, calculate the best split for the column
int bestSplit = -1;
double bestFitness = 0.0; // Fitness to maximize
for( int i = 0; i < _columnDists[colIndex].length-1; ++i ) {
// first copy the i-th guys from rite to left
int sumt = 0;
for( int j = 0; j < inclDist.length; ++j ) {
int t = _columnDists[colIndex][i][j];
sumt += t;
inclDist[j] = t;
exclDist[j] = dist[j] - t;
}
int inclW = sumt;
int exclW = distWeight - inclW;
// now make sure we have something to split
if( inclW == 0 || exclW == 0 ) continue;
double f = ((double)inclW / distWeight) * ((double)exclW / distWeight) *
twoing(inclDist, inclW, exclDist, exclW);
if( f>bestFitness ) { // Take split with largest fitness
bestSplit = i;
bestFitness = f;
}
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist, _random))
: Split.exclusion(colIndex, bestSplit, bestFitness);
}
@Override
protected Split ltSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}
@Override
protected Split eqSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}
}