package hex.singlenoderf;
import hex.singlenoderf.Data.Row;
import water.util.Utils;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
/** Keeps track of the column distributions and analyzes the column splits,
 *  ultimately producing the single split that will be used for the node. */
abstract class Statistic {
  /** Column distributions: column x arity x classes.
   *  Counts the rows for each (column index, encoded value, class) triple. */
protected final int[][][] _columnDists;
protected final int[] _features; // Columns/features that are currently used.
protected Random _random; // Pseudo random number generator
private long _seed; // Seed for prng
private HashSet<Integer> _remembered; // Features already used
final double[] _classWt; // Class weights
private int _exclusiveSplitLimit;
protected final int[/*num_features*/][/*column_bins*/][/*response_bins*/] _columnDistsRegression;
boolean _regression;
  /** Return the best less-than (lt) or equality (eq) split for the given column;
   *  the int[] overloads handle classification, the float[] overloads regression. */
protected abstract Split ltSplit(int colIndex, Data d, int[] dist, int distWeight, Random rand);
protected abstract Split eqSplit(int colIndex, Data d, int[] dist, int distWeight, Random rand);
protected abstract Split ltSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand);
protected abstract Split eqSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand);
  /** Split descriptor for a particular column.
   * Holds the column index and the split point, which is the last encoded
   * column value that goes to the left subtree. If the column index is
   * negative, the split value is the prediction returned by the node.
   */
static class Split {
final int _column, _split;
final float _splitReg;
final double _fitness;
Split(int column, int split, double fitness) {
_column = column; _split = split; _fitness = fitness; _splitReg = -1.f;
}
/** A constant split used for true leaf nodes where all rows are of the same class. */
static Split constant(int result) { return new Split(-1, result, -1); }
    /** An impossible split, which behaves like a constant split. An impossible
     * split occurs when different row classes are present, but they all have
     * the same column value, so no split can be made.
     */
static Split impossible(int result) { return new Split(-2, result, -1); }
    /** Classic split. All rows with a value less than or equal to the split value go left; all greater go right. */
static Split split(int column, int split, double fitness) { return new Split(column, split,fitness); }
    /** Return an impossible split with the worst possible fitness, so that any real split beats it. */
static Split defaultSplit() { return new Split(-2,0,-Double.MAX_VALUE); }
    /** Exclusion split. All rows equal to the split value go left; all others go right. */
static Split exclusion(int column, int split, double fitness) { return new ExclusionSplit(column,split,fitness); }
final boolean isLeafNode() { return _column < 0; }
final boolean isConstant() { return _column == -1; }
final boolean isImpossible() { return _column == -2; }
final boolean betterThan(Split other) { return _fitness > other._fitness; }
final boolean isExclusion() { return this instanceof ExclusionSplit; }
}
/** An exclusion split. */
static class ExclusionSplit extends Split {
ExclusionSplit(int column, int split, double fitness) { super(column, split,fitness); }
}
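  /* Illustrative sketch only (not part of the algorithm): how the sentinel
   * values in Split._column are read back by the predicates above. The numbers
   * are hypothetical.
   *
   *   Split s    = Split.split(3, 7, 0.42);  // real split on column 3, point 7
   *   Split leaf = Split.constant(2);        // leaf predicting class 2
   *   Split none = Split.defaultSplit();     // impossible split, worst fitness
   *   s.isLeafNode();      // false: _column >= 0
   *   leaf.isConstant();   // true:  _column == -1
   *   none.isImpossible(); // true:  _column == -2
   *   s.betterThan(none);  // true:  any real split beats the default
   */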
  /** Aggregates the given column's distribution into the provided per-class
   *  array and returns the total (weighted) row count. */
private int aggregateColumn(int colIndex, int[] dist) {
int sum = 0;
for (int j = 0; j < _columnDists[colIndex].length; ++j) {
for (int i = 0; i < dist.length; ++i) {
int tmp = _columnDists[colIndex][j][i];
sum += tmp;
dist[i] += tmp;
}
}
return sum;
}
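  /* Worked example with hypothetical counts: for a column with two bins and
   * three classes, _columnDists[colIndex] = {{4,1,0},{2,3,5}} accumulates
   * dist = {6,4,5} and the method returns 6+4+5 = 15, the total row weight. */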
  /** Sums the given column's regression distribution into the provided array.
   *
   * @param colIndex the column being aggregated
   * @param dist     accumulates the count for each response bin
   * @return the total count; the caller treats this as the unweighted "mean" weight
   */
private float aggregateColumn(int colIndex, float[] dist) {
float sum = 0.f;
for (int j = 0; j < _columnDistsRegression[colIndex].length; ++j) {
for (int i = 0; i < dist.length; ++i) {
float tmp = _columnDistsRegression[colIndex][j][i];
sum += tmp;
dist[i] += tmp;
}
}
return sum;
}
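  /* Worked example with hypothetical counts: for two column bins and three
   * response bins, _columnDistsRegression[colIndex] = {{2,0,1},{0,4,3}}
   * accumulates dist = {2,4,4} and the method returns 10; split() uses that
   * total as the unweighted weight of the response distribution. */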
  Statistic(Data data, int featuresPerSplit, long seed, int exclusiveSplitLimit, boolean regression) {
    _columnDistsRegression = new int[data.columns() - 1][][];
    _columnDists = new int[data.columns() - 1][][];
    _regression = regression;
    _random = Utils.getRNG(seed);
    // First create the per-column distributions for the columns we may split on.
    if (!regression) {
      for (int i = 0; i < _columnDists.length; ++i)
        if (!data.isIgnored(i))
          _columnDists[i] = new int[data.columnArity(i)+1][data.classes()];
    } else {
      for (int i = 0; i < _columnDistsRegression.length; ++i)
        if (!data.isIgnored(i)) {
          DataAdapter.Col c = data._dapt._c[i];
          int colBins = c._isByte ? Utils.maxValue(c._rawB) : c._binned.length;
          _columnDistsRegression[i] = new int[colBins + 1][data.columnArityOfClassCol()];
        }
    }
    // Then the state shared by both modes: selected features, class weights, split limit.
    _features = new int[featuresPerSplit];
    _remembered = null;
    _classWt = data.classWt(); // Class weights
    _exclusiveSplitLimit = exclusiveSplitLimit;
  }
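  /* Shape sketch of the two histograms built above: for classification,
   * _columnDists[i] is [columnArity(i)+1][classes()] -- one counter per
   * (encoded column value, class) pair. For regression the response is binned
   * too, so _columnDistsRegression[i] is [column_bins+1][response_bins]. */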
  /** Remember the features used for this split so that the next attempt can
   * pick different ones and avoid the useless ones. Returns false if no
   * usable features are left. */
boolean rememberFeatures(Data data) {
if( _remembered == null ) _remembered = new HashSet<Integer>();
for(int f : _features) if ( f != -1 ) _remembered.add(f);
for(int i=0;i<data.columns()-1;i++) if(isColumnUsable(data,i)) return true;
return false;
}
  /** We are done with this particular split and can forget the features we
   * used to compute it. */
  void forgetFeatures() { _remembered = null; }
  /** A feature can be selected for a split if it is not ignored, has not
   * already been tried at this node, and is not constant. */
private boolean isColumnUsable(Data d, int i) {
assert i < d.columns()-1; // Last column is class
return !d.isIgnored(i) && (_remembered == null || !_remembered.contains(i)) && d.colMaxIdx(i) != d.colMinIdx(i);
}
  /** Resets the statistic for the next split. Picks a subset of the features
   * and zeroes out their distributions. Feature selection uses reservoir
   * sampling (http://en.wikipedia.org/wiki/Reservoir_sampling). Features that
   * (a) are marked as ignored, (b) have already been tried at this split, or
   * (c) are the class feature will not be selected. */
  void reset(Data data, long seed, boolean regression) {
    _random = Utils.getRNG(_seed = seed);
    int i = 0, j = 0, featuresPerSplit = _features.length;
    Arrays.fill(_features, -1);
    // Fill the reservoir with the first usable columns...
    for( ; j < featuresPerSplit && i < data.columns()-1; i++) if (isColumnUsable(data, i)) _features[j++] = i;
    // ...then give each later usable column a chance to replace a reservoir slot.
    for( ; i < data.columns()-1; i++ ) {
      if( !isColumnUsable(data, i) ) continue;
      int k = _random.nextInt(j+1); // Reservoir sampling: a random slot in [0,j] (inclusive)
      if( k < featuresPerSplit ) _features[k] = i;
      j++;
    }
    // Zero out the distributions of the selected features.
    int[][][] dists = regression ? _columnDistsRegression : _columnDists;
    for( int f : _features) if (f != -1) for( int[] d : dists[f]) Arrays.fill(d,0);
  }
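  /* Minimal standalone sketch of the reservoir sampling performed in reset()
   * (names are illustrative only; `rng` and `candidates` are assumed):
   *
   *   int[] reservoir = new int[k];          // k = featuresPerSplit
   *   int seen = 0;
   *   for (int col : candidates) {           // usable columns, in order
   *     if (seen < k) reservoir[seen] = col; // fill the reservoir first
   *     else {
   *       int r = rng.nextInt(seen + 1);     // uniform in [0, seen]
   *       if (r < k) reservoir[r] = col;     // replace with probability k/(seen+1)
   *     }
   *     seen++;
   *   }
   *
   * Every usable column lands in the reservoir with equal probability k/n. */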
/** Adds the given row to the statistic. Updates the column distributions for
* the analyzed columns. */
  void addQ(Row row, boolean regression) {
    final int cls = row.classOf();
    for (int f : _features) {
      if (f == -1 || !row.isValid() || !row.hasValidValue(f)) continue;
      short val = row.getEncodedColumnValue(f);
      if (!regression) {
        _columnDists[f][val][cls]++;
      } else {
        if (val == DataAdapter.BAD) continue;  // skip unbinnable column values
        int resp = row.getEncodedClassColumnValue();
        if (resp == DataAdapter.BAD) continue; // skip unbinnable responses
        _columnDistsRegression[f][val][resp]++;
      }
    }
  }
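  /* Example of a single update (hypothetical row): a valid classification row
   * with encoded value 3 in feature f and class 1 executes
   * _columnDists[f][3][1]++; a regression row with response bin 2 executes
   * _columnDistsRegression[f][3][2]++ instead. */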
  // /** Adds the given row to the statistic. Updates the column distributions for
  //  *  the analyzed columns. This version knows the row is always valid (always
  //  *  has a valid class), and is hand-inlined. */
// void addQValid( final int cls, final int ridx, final DataAdapter.Col cs[]) {
// for (int f : _features) {
// if (f == -1) break;
// short[] bins = cs[f]._binned; // null if byte col, otherwise bin#
// int val;
// if (bins != null) { // binned?
// val = bins[ridx]; // Grab bin#
// if (val == DataAdapter.BAD) continue; // ignore bad rows
// } else { // not binned?
// val = (0xFF & cs[f]._rawB[ridx]); // raw byte value, has no bad rows
// }
// _columnDists[f][val][cls]++;
// }
// }
  /** Apply any class weights to the distributions. No-op for regression. */
  void applyClassWeights() {
    if( _classWt == null || _regression ) return;
for( int f : _features ) // For all columns, get the distribution
if ( f != -1)
for( int[] clss : _columnDists[f] ) // For all distributions, get the class distribution
for( int cls=0; cls<clss.length; cls++ )
clss[cls] = (int)(clss[cls]*_classWt[cls]); // Scale by the class weights
}
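  /* Worked example with hypothetical weights: _classWt = {1.0, 2.5} turns a
   * bin holding counts {4, 2} into {4, 5}, so minority-class rows count 2.5x
   * when split fitness is evaluated (products are truncated to int). */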
  /** Calculates the best split and returns it. The split is either a classic
   * split, where all rows with a column value less than or equal to the split
   * value go left and all greater go right, or an exclusion split, where all
   * rows with a column value equal to the split value go left and all others
   * go right. If every candidate feature yields an impossible split, a fresh
   * feature subset is sampled and the distributions are rebuilt before retrying.
   */
Split split(Data d, boolean expectLeaf) {
if(!_regression) {
int[] dist = new int[d.classes()];
boolean valid = false;
for(int f : _features) valid |= f != -1;
if (!valid) return Split.defaultSplit(); // there are no features left...
int distWeight = aggregateColumn(_features[0], dist); // initialize the distribution array
int m = Utils.maxIndex(dist, _random);
      if( expectLeaf || (dist[m] == distWeight )) return Split.constant(m); // check if we are a leaf node
Split bestSplit = Split.defaultSplit();
for( int f : _features ) { // try the splits
if ( f == -1 ) continue;
Split s = pickAndSplit(d,f, dist, distWeight, _random);
if( s.betterThan(bestSplit) ) bestSplit = s;
}
if( !bestSplit.isImpossible() ) return bestSplit;
if( !rememberFeatures(d) ) return bestSplit; // Enough features to try again?
reset(d,_seed+(1L<<16), _regression); // Reset with new features
for(Row r: d) addQ(r, _regression); // Reload the distributions
applyClassWeights(); // Weight the distributions
return split(d,expectLeaf);
    } else {
      float[] dist = new float[d.columnArityOfClassCol()];
      boolean valid = false;
      for(int f: _features) valid |= f != -1;
      if(!valid) return Split.defaultSplit();       // there are no features left...
      float unweightedMean = aggregateColumn(_features[0], dist); // total weight of the response distribution
      int m = Utils.maxIndex(dist, _random);
      if(expectLeaf || (dist[m] == unweightedMean)) return Split.constant(m); // all weight in one response bin => leaf
      Split bestSplit = Split.defaultSplit();
      for (int f: _features) {                      // try the splits
        if (f == -1) continue;
        Split s = pickAndSplit(d,f,dist,unweightedMean,_random);
        if (s.betterThan(bestSplit)) bestSplit = s;
      }
      if (!bestSplit.isImpossible()) return bestSplit;
      if (!rememberFeatures(d)) return bestSplit;   // Enough features to try again?
      reset(d, _seed+(1L<<16), _regression);        // Reset with new features
      for(Row r: d) addQ(r,_regression);            // Reload the distributions
      return split(d, expectLeaf);
    }
}
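  /* The two overloads below share one decision rule: boolean columns get only
   * an equality split; float columns and categoricals with arity above
   * _exclusiveSplitLimit get only a less-than split; small categoricals try
   * both kinds and keep the fitter one. */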
  private Split pickAndSplit(Data d, int col, int[] dist, int distWeight, Random rand) {
    boolean isBool  = d.columnArity(col) == 1; // arity of 1 actually encodes two levels (boolean)
    boolean isBig   = d.columnArity(col) > _exclusiveSplitLimit;
    boolean isFloat = d.isFloat(col);
    if (isBool) return eqSplit(col, d, dist, distWeight, rand);
    else if (isBig || isFloat) return ltSplit(col, d, dist, distWeight, rand);
    else { // small categorical: try both kinds and keep the fitter one
      Split s1 = eqSplit(col, d, dist, distWeight, rand);
      if (s1.isImpossible()) return s1;
      Split s2 = ltSplit(col, d, dist, distWeight, rand);
      return s1.betterThan(s2) ? s1 : s2;
    }
  }
  private Split pickAndSplit(Data d, int col, float[] dist, float distWeight, Random rand) {
    boolean isBool  = d.columnArity(col) == 1; // arity of 1 actually encodes two levels (boolean)
    boolean isBig   = d.columnArity(col) > _exclusiveSplitLimit;
    boolean isFloat = d.isFloat(col);
    if (isBool) return eqSplit(col, d, dist, distWeight, rand);
    else if (isBig || isFloat) return ltSplit(col, d, dist, distWeight, rand);
    else { // small categorical: try both kinds and keep the fitter one
      Split s1 = eqSplit(col, d, dist, distWeight, rand);
      if (s1.isImpossible()) return s1;
      Split s2 = ltSplit(col, d, dist, distWeight, rand);
      return s1.betterThan(s2) ? s1 : s2;
    }
  }
}