package hex.singlenoderf;
import water.util.Utils;
import java.util.Random;
/**
 * Computes split statistics for regression trees.
 *
 * <p>Response values are handled in binned form. Per-column distributions are read from
 * {@code _columnDistsRegression}. Bin indices are mapped back to raw response values through
 * the {@code _binned2raw} table of the last entry of {@code d._dapt._c} (presumably the response
 * column; confirm).
 *
 * <p>NOTE(review): the class name suggests a squared-error criterion, but {@code ltSplit}
 * minimizes the sum of the two sides' mean responses, and {@code eqSplit} maximizes
 * {@code sumIncl^2/nIncl + sumExcl^2/nExcl}. Confirm these are the intended criteria.
 */
public class MSEStatistic extends Statistic {

  /**
   * Creates a regression statistic. All arguments are forwarded to {@link Statistic}, with the
   * regression flag set to {@code true}.
   */
  public MSEStatistic(Data data, int features, long seed, int exclusiveSplitLimit) {
    super(data, features, seed, exclusiveSplitLimit, true /*regression*/);
  }

  /**
   * Returns the average raw response value described by a binned distribution.
   *
   * @param dist per-response-bin counts; entry {@code i} is truncated to an int and weighted by
   *     the raw value of response bin {@code i}
   * @param d data whose last adapted column supplies the bin-to-raw mapping
   * @param sum number of observations to average over (the divisor)
   * @return the weighted sum divided by {@code sum}, or {@link Float#POSITIVE_INFINITY} when
   *     {@code sum} is zero, which keeps a split with an empty side from winning in
   *     {@code ltSplit}
   */
  private float computeAv(float[] dist, Data d, int sum) {
    float res = 0f;
    for (int i = 0; i < dist.length; ++i) {
      int count = (int) dist[i];
      res += d._dapt._c[d._dapt._c.length - 1]._binned2raw[i] * count;
    }
    return sum == 0 ? Float.POSITIVE_INFINITY : res / (float) sum;
  }

  /**
   * Sums the regression distribution of column {@code colIndex} over its bins.
   *
   * <p>The loop bounds skip both the last column bin and the last response entry of every bin.
   * NOTE(review): {@code ltSplit} starts its right-hand row count from {@code d.rows()} (all
   * rows) but starts its right-hand distribution from this partial sum. Confirm that excluding
   * the last bin is intentional.
   *
   * @param d data supplying the response arity
   * @param colIndex column whose distribution is summed
   * @return array of length {@code d.columnArityOfClassCol()} holding per-response-bin totals
   */
  private float[] computeDist(Data d, int colIndex) {
    float[] res = new float[d.columnArityOfClassCol()];
    for (int i = 0; i < _columnDistsRegression[colIndex].length - 1; ++i) {
      for (int j = 0; j < _columnDistsRegression[colIndex][i].length - 1; ++j) {
        res[j] += _columnDistsRegression[colIndex][i][j];
      }
    }
    return res;
  }

  /**
   * Finds the best "less-than" split point for an ordered column.
   *
   * <p>The method sweeps the column bins left to right, moving each bin's response counts from
   * the right side to the left side. Each candidate is scored as the sum of the mean response on
   * the two sides, and lower is better. Candidate {@code j} places bins {@code 0..j} on the left.
   * The sweep stops before the last bin, so the right side always keeps that bin.
   *
   * @param rand unused; the fallback tie-break uses the {@code _random} field instead
   * @return the best split, or {@code Split.impossible} built from {@code Utils.maxIndex} of the
   *     column's summed distribution when no finite candidate exists
   */
  @Override
  protected Split ltSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
    float bestSoFar = Float.POSITIVE_INFINITY;
    int bestSplit = -1;
    int lW = 0;        // observations moved to the left side so far
    int rW = d.rows(); // observations still on the right side
    float[] leftDist = new float[d.columnArityOfClassCol()];
    float[] riteDist = computeDist(d, colIndex);
    for (int j = 0; j < _columnDistsRegression[colIndex].length - 1; ++j) {
      // Move column bin j from the right side to the left side.
      for (int i = 0; i < dist.length; ++i) {
        int t = _columnDistsRegression[colIndex][j][i];
        lW += t;
        rW -= t;
        leftDist[i] += t;
        riteDist[i] -= t;
      }
      float yRite = computeAv(riteDist, d, rW);
      float yLeft = computeAv(leftDist, d, lW);
      float newSplitValue = yRite + yLeft;
      if (newSplitValue < bestSoFar) {
        bestSoFar = newSplitValue;
        bestSplit = j;
      }
    }
    return (bestSplit == -1 || bestSoFar == Float.POSITIVE_INFINITY)
        ? Split.impossible(Utils.maxIndex(computeDist(d, colIndex), _random))
        : Split.split(colIndex, bestSplit, bestSoFar);
  }

  /**
   * Finds the best "equality" (exclusion) split for a column.
   *
   * <p>The method sweeps {@code dist}, moving one entry at a time from the excluded side to the
   * included side. Each step is scored as {@code sumIncl^2/nIncl + sumExcl^2/nExcl}, and higher
   * is better. The sweep stops as soon as the excluded side would hold no observations. Without
   * that stop, the division by a zero count yields {@code Infinity}/{@code NaN}, and an
   * {@code Infinity} fitness would beat every real candidate.
   *
   * @param rand unused; the fallback tie-break uses the {@code _random} field instead
   * @return the best exclusion split, or {@code Split.impossible} built from
   *     {@code Utils.maxIndex(dist, _random)} when no candidate scores above zero
   */
  @Override
  protected Split eqSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
    int bestSplit = -1;
    float bestSoFar = 0.f; // fitness to maximize
    for (int i = 0; i < _columnDists[colIndex].length - 1; ++i) {
      // NOTE(review): the sweep below does not depend on i. Every outer iteration therefore
      // recomputes the same scores, and because the comparison is strict only i == 0 can ever
      // be selected. Confirm against the intended per-bin exclusion criterion.
      float sumIncl = 0.f;
      float sumExcl = distWeight;
      int nIncl = 0;
      int nExcl = d.rows();
      for (float aDist : dist) {
        sumIncl += aDist;
        sumExcl -= aDist;
        nIncl++;
        nExcl--;
        if (nExcl <= 0) {
          // Excluded side is empty: later steps would divide by zero or by a negative count.
          break;
        }
        float newSplitValue = (sumIncl * sumIncl / (float) nIncl) + (sumExcl * sumExcl / (float) nExcl);
        if (newSplitValue > bestSoFar) {
          bestSoFar = newSplitValue;
          bestSplit = i;
        }
      }
    }
    return bestSplit == -1
        ? Split.impossible(Utils.maxIndex(dist, _random))
        : Split.exclusion(colIndex, bestSplit, bestSoFar);
  }

  /** Integer-distribution variant used for classification; returns {@code null} here. */
  @Override
  protected Split ltSplit(int col, Data d, int[] dist, int distWeight, Random r) {
    return null; // not called for regression
  }

  /** Integer-distribution variant used for classification; returns {@code null} here. */
  @Override
  protected Split eqSplit(int colIndex, Data d, int[] dist, int distWeight, Random r) {
    return null; // not called for regression
  }
}