package hex.singlenoderf;
import water.MemoryManager;
import water.util.Utils;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
/**
 * A read-only view over the rows of a {@link DataAdapter}, used while building
 * random-forest trees on a single node. The base class exposes every adapter
 * row unchanged; {@code Subset} (same file) narrows the view through a
 * permutation array by overriding {@link #start()}, {@link #end()} and
 * {@link #permute(int)}.
 *
 * NOTE: iteration reuses a single mutable {@link Row} cursor — callers must
 * not retain Row references across iteration steps.
 */
public class Data implements Iterable<Data.Row> {
  /** Use stratified sampling */
  boolean _stratify;

  /** Random generator to make decision about missing data. */
  final Random _rng;

  /** Lightweight cursor over one row of the backing adapter. Reused by the iterator. */
  public final class Row {
    int _index; // row index into the backing DataAdapter

    public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append(_index).append(" [").append(classOf()).append("]:");
      for( int i = 0; i < _dapt.columns(); ++i ) sb.append(_dapt.hasBadValue(_index, i) ? "NA" : _dapt.getEncodedColumnValue(_index, i)).append(',');
      return sb.toString();
    }
    /** Class (response category) of this row, as stored by the adapter. */
    public int classOf() { return _dapt.classOf(_index); }
    /** Binned (encoded) value of the given predictor column for this row. */
    public final short getEncodedColumnValue(int colIndex) {
      return _dapt.getEncodedColumnValue(_index, colIndex); }
    /** Binned (encoded) value of the response column for this row. */
    public final short getEncodedClassColumnValue() {
      return _dapt.getEncodedClassColumnValue(_index);
    }
    /** Raw (un-binned) response value reconstructed from this row's bin. */
    public final float getRawClassColumnValueFromBin() {
      return _dapt.getRawClassColumnValueFromBin(_index);
    }
    /** True when the given column holds a usable (non-NA) value for this row. */
    public final boolean hasValidValue(int colIndex) { return !_dapt.hasBadValue(_index, colIndex); }
    /** True when the whole row is usable (no NA that disqualifies it). */
    public final boolean isValid() { return !_dapt.isBadRow(_index); }
    /** Validity check against the raw (un-binned) representation. */
    public final boolean isValidRaw() { return !_dapt.isBadRowRaw(_index); }
    /** Raw (un-binned) value of the given column for this row. */
    public final double getRawColumnValue(int colIndex) { return _dapt.getRawColumnValue(_index, colIndex); }
  }

  /** Backing adapter holding the actual (binned) column data. */
  protected final DataAdapter _dapt;

  /** Returns new Data object that stores all adapter's rows unchanged. */
  public static Data make(DataAdapter da) { return new Data(da); }

  protected Data(DataAdapter dapt) {
    _dapt = dapt;
    // Deterministic RNG so tree construction is reproducible for a fixed seed.
    _rng = Utils.getDeterRNG(0x7b85dfe19122f0d5L);
    _columnInfo = new ColumnInfo[_dapt.columns()];
    for(int i = 0; i<_columnInfo.length; i++)
      _columnInfo[i] = new ColumnInfo(i);
  }

  /** First row position (inclusive) of this view; Subset overrides to narrow it. */
  protected int start() { return 0; }
  /** One past the last row position (exclusive) of this view. */
  protected int end() { return _dapt._numRows; }
  /** Number of rows visible through this view. */
  public final int rows() { return end() - start(); }
  public final int columns() { return _dapt.columns(); }
  public final int classes() { return _dapt.classes(); }
  public final long seed() { return _dapt.seed(); }
  public final String colName(int i) { return _dapt.columnName(i); }
  /** Maps a (column, split-bin) pair back to the raw split threshold. */
  public final float unmap(int col, int split) { return _dapt.unmap(col, split); }
  public final int columnArity(int colIndex) { return _dapt.columnArity(colIndex); }
  public final int columnArityOfClassCol() { return _dapt.columnArityOfClassCol(); }
  /** Transforms given binned index (short) into 0..N-1 corresponding to predictor class */
  public final int unmapClass(int clazz) {return _dapt.unmapClass(clazz); }
  public final boolean isFloat(int col) { return _dapt.isFloat(col); }
  public final double[] classWt() { return _dapt._classWt; }
  public final boolean isIgnored(int col) { return _dapt.isIgnored(col); }

  /**
   * Mean of the raw response over the valid rows of this view (used as the
   * baseline prediction for regression). Returns 0 for an empty/all-invalid view.
   */
  public final float computeAverage() {
    float av = 0.f;
    int nobs = 0;
    for (Row r: this) {
      if (r.isValid()) {
        av += r.getRawClassColumnValueFromBin();
        // FIX: count only rows that contributed to the sum; previously every
        // row (valid or not) was counted, biasing the average low whenever
        // invalid rows were present.
        nobs++;
      }
    }
    return nobs == 0 ? 0 : av / (float)(nobs);
  }

  /** Raw predictor values of a row as a double[] (response column excluded). */
  public double[] unpackRow(Row r) {
    double[] res = new double[_dapt._c.length-1];
    for (int i = 0; i < _dapt._c.length-1; ++i) res[i] = r.getRawColumnValue(i);
    return res;
  }

  /** Row cursor positioned at view position i (permuted to the adapter index). */
  public Row at(int i) { Row _r = new Row(); _r._index = permute(i); return _r;}
  public final Iterator<Row> iterator() { return new RowIter(start(), end()); }

  /** Iterator over [start,end); reuses one Row instance for all elements. */
  private class RowIter implements Iterator<Row> {
    final Row _r = new Row();
    int _pos = 0; final int _end;
    public RowIter(int start, int end) { _pos = start; _end = end; }
    public boolean hasNext() { return _pos < _end; }
    public Row next() { _r._index = permute(_pos++); return _r; }
    public void remove() { throw new RuntimeException("Unsupported"); }
  }

  // ----------------------
  /**
   * Partition rows around a split when invalid (NA) values may be present.
   * Rows with NA in the split column are sent left/right at random. Updates
   * the left/right statistics as a side effect and returns the boundary index:
   * positions [start,l) go left, [l,end) go right.
   */
  private int filterInv(Tree.SplitNode node, int[] permutation, Statistic ls, Statistic rs) {
    final Row row = new Row();
    int l = start(), r = end() - 1;
    while (l <= r) {
      int permIdx = row._index = permutation[l];
      boolean putToLeft;
      if (node.canDecideAbout(row)) { // are we splitting over existing value
        putToLeft = node.isIn(row);
      } else { // make a random choice about non
        putToLeft = _rng.nextBoolean();
      }
      if (putToLeft) {
        ls.addQ(row, ls._regression);
        ++l;
      } else {
        rs.addQ(row, rs._regression);
        // swap the right-bound row to the back of the unprocessed region
        permutation[l] = permutation[r];
        permutation[r--] = permIdx;
      }
    }
    return l;
  }

  /** Adapter row indices visible through this view (i.e. not out-of-bag). */
  public long[] nonOOB() {
    ArrayList<Integer> res = new ArrayList<Integer>();
    for (Row r : this) res.add(r._index);
    long[] rr = new long[res.size()];
    for (int i = 0; i < rr.length; ++i) rr[i] = res.get(i);
    return rr;
  }

  // Filter a column, with all valid data. i.e., skip the invalid check
  private int filterVal(Tree.SplitNode node, int[] permutation, Statistic ls, Statistic rs) {
    final int l =filterVal1(node,permutation);
    filterVal3(permutation,ls,start(),l);
    filterVal3(permutation,rs,l,end());
    return l;
  }

  // Hand-inlining for performance... CNC
  /** Partition pass only: reorder permutation so bins <= split come first; returns boundary. */
  private int filterVal1(Tree.SplitNode node, int[] permutation) {
    int cidx = node._column;         // Decision column guiding the split
    DataAdapter.Col cs[] = _dapt._c;
    short bins[] = cs[cidx]._binned; // Bin#'s for each row
    byte binb[] = cs[cidx]._rawB;    // Bin#'s for each row
    int split = node._split;         // Value to split on
    // Move the data into left/right halves
    int l = start(), r = end() - 1;
    while (l <= r) {
      int permIdx = permutation[l];
      // byte columns are stored unbinned; mask to treat them as unsigned
      int val = bins==null ? (0xFF&binb[permIdx]) : bins[permIdx];
      if( val <= split ) {
        ++l;
      } else {
        permutation[l] = permutation[r];
        permutation[r--] = permIdx;
      }
    }
    return l;
  }

  // Update the histogram
  /**
   * Histogram pass: for rows [lo,hi) of the permutation, bump the statistic's
   * per-feature (value x class) counts. Classification and regression keep
   * separate histogram arrays.
   */
  private void filterVal3(int[] permutation, Statistic s, final int lo, final int hi) {
    if (!s._regression) {
      DataAdapter.Col cs[] = _dapt._c;
      short classs[]= cs[_dapt.classColIdx()]._binned;
      int cds[][][] = s._columnDists;
      int fs[] = s._features;
      // Run this loop by-feature instead of by-row - so that the updates in the
      // inner loops do not need to start from loading the feature array.
      for (int f : fs) {
        if (f == -1) break;             // Short features.
        int cdsf[][] = cds[f];          // Histogram per-column (by value & class)
        short[] bins = cs[f]._binned;   // null if byte col, otherwise bin#
        if (bins != null) {             // binned?
          for (int i = lo; i < hi; i++) {      // Binned-loop
            int permIdx = permutation[i];      // Get the row
            int val = bins[permIdx];           // Bin-for-row
            if (val == DataAdapter.BAD) continue; // ignore bad rows
            int cls = classs[permIdx];         // Class-for-row
            if (cls == DataAdapter.BAD) continue; // ignore rows with NA in response column
            cdsf[val][cls]++;                  // Bump histogram
          }
        } else {                        // not binned?
          byte[] raw = cs[f]._rawB;     // Raw unbinned byte array
          for (int i = lo; i < hi; i++) {      // not-binned loop
            int permIdx = permutation[i];      // Get the row
            int val = (0xFF & raw[permIdx]);   // raw byte value, has no bad rows
            int cls = classs[permIdx] & 0xFF;  // Class-for-row
            cdsf[val][cls]++;                  // Bump histogram
          }
        }
      }
    } else {
      DataAdapter.Col cols[] = _dapt._c;
      // NOTE(review): this rebuilds the full per-row response array on every
      // call (O(numRows) per split) — a candidate for caching on the adapter.
      float[] response;
      if (cols[_dapt.classColIdx()]._binned == null) {
        response = new float[cols[_dapt.classColIdx()]._rawB.length];
        for (int b = 0; b < response.length; ++b)
          response[b] = (float)(0xFF & cols[_dapt.classColIdx()]._rawB[b]);
      } else {
        response = new float[cols[_dapt.classColIdx()]._binned.length];
        for (int f = 0; f < response.length; ++f)
          response[f] = cols[_dapt.classColIdx()]._binned2raw[cols[_dapt.classColIdx()]._binned[f]];
      }
      int cds[][][] = s._columnDistsRegression;
      int fs[] = s._features;
      for (int f: fs) {
        if (f == -1) break;
        int cdsf[][] = cds[f];
        short[] bins = cols[f]._binned;
        if (bins != null) {
          for (int i = lo; i < hi; i++) {
            int permIdx = permutation[i];
            int val = bins[permIdx];
            if (val == DataAdapter.BAD) continue; // ignore bad rows
            float resp = response[permIdx];       // Class-for-row
            int response_bin = _dapt.getEncodedClassColumnValue(permIdx);
            // NOTE(review): comparing the float response against the BAD
            // sentinel — only meaningful if BAD round-trips through the
            // raw-response mapping; confirm against DataAdapter.
            if (resp == DataAdapter.BAD) continue; // ignore rows with NA in response column
            cdsf[val][response_bin]++;             // Bump histogram
          }
        } else {
          byte[] raw = cols[f]._rawB;
          for (int i = lo; i < hi; i++) {
            int permIdx = permutation[i];
            int val = raw[permIdx]&0xFF;
            if (val == DataAdapter.BAD) continue;
            short resp = cols[cols.length-1]._binned[permIdx];
            if (resp == DataAdapter.BAD) continue;
            int response_bin = _dapt.getEncodedClassColumnValue(permIdx);
            cdsf[val][response_bin]++;
          }
        }
      }
    }
  }

  /**
   * Splits this view into left/right children according to the split node,
   * filling {@code result[0]} (left) and {@code result[1]} (right) and
   * accumulating the child statistics. Chooses the slow NA-aware path only
   * when the split column or the response column contains invalid values.
   */
  public void filter(Tree.SplitNode node, Data[] result, Statistic ls, Statistic rs) {
    int[] permutation = getPermutationArray();
    int cidx = node._column;
    int l = _dapt.hasAnyInvalid(cidx) || _dapt.hasAnyInvalid(_dapt.columns()-1)
        ? filterInv(node,permutation,ls,rs)
        : filterVal(node,permutation,ls,rs);
    // Children share column info except for the split column's narrowed range.
    ColumnInfo[] linfo = _columnInfo.clone();
    ColumnInfo[] rinfo = _columnInfo.clone();
    linfo[node._column]= linfo[node._column].left(node._split);
    rinfo[node._column]= rinfo[node._column].right(node._split);
    result[0]= new Subset(this, permutation, start(), l);
    result[1]= new Subset(this, permutation, l, end());
    result[0]._columnInfo = linfo;
    result[1]._columnInfo = rinfo;
  }

  /**
   * Bootstrap sample (with replacement) of {@code bagSizePct * rows()} rows.
   * {@code complement} is filled with the per-row draw counts so the caller
   * can later derive the out-of-bag rows; the returned sample is ordered.
   */
  public Data sampleWithReplacement(double bagSizePct, short[] complement) {
    // Make sure that values come in order
    int size = (int)(rows() * bagSizePct);
    /* NOTE: Before changing used generator think about which kind of random generator you need:
     * if always deterministic or non-deterministic version - see hex.speedrf.Utils.get{Deter}RNG */
    Random r = Utils.getRNG(seed());
    for( int i = 0; i < size; ++i)
      complement[permute(r.nextInt(rows()))]++;
    int[] sample = MemoryManager.malloc4(size);
    for( int i = 0, j = 0; i < sample.length;) {
      while(complement[j]==0) j++;
      for (int k = 0; k < complement[j]; k++) sample[i++] = j;
      j++;
    }
    return new Subset(this, sample, 0, sample.length);
  }

  /** Out-of-bag complement; only meaningful on {@link Subset}. */
  public Data complement(Data parent, short[] complement) { throw new RuntimeException("Only for subsets."); }

  // NOTE(review): intentionally returns this (shared view), violating the
  // usual clone() contract; Subset overrides with a real copy.
  @Override public Data clone() throws CloneNotSupportedException { return this; }

  /** Maps a view position to an adapter row index; identity for the full view. */
  protected int permute(int idx) { return idx; }

  /** Fresh identity permutation over this view's rows. */
  protected int[] getPermutationArray() {
    int[] perm = MemoryManager.malloc4(rows());
    for( int i = 0; i < perm.length; ++i ) perm[i] = i;
    return perm;
  }

  public int colMinIdx(int i) { return _columnInfo[i].min; }
  public int colMaxIdx(int i) { return _columnInfo[i].max; }

  /** Tracks the still-reachable bin range [min,max] of a column along a tree path. */
  class ColumnInfo {
    private final int col;
    int min, max;
    ColumnInfo(int col_) { col=col_; max = _dapt.columnArity(col_) - 1; }
    /** Range after following the left branch of a split at bin idx. */
    ColumnInfo left(int idx) {
      ColumnInfo res = new ColumnInfo(col);
      res.max = idx < max ? idx : max;
      res.min = min;
      return res;
    }
    /** Range after following the right branch of a split at bin idx. */
    ColumnInfo right(int idx) {
      ColumnInfo res = new ColumnInfo(col);
      res.min = idx >= min ? (idx+1) : min;
      res.max = max;
      return res;
    }
    public String toString() { return col + "["+ min +","+ max + "]"; }
  }

  ColumnInfo[] _columnInfo;
}
/**
 * A {@link Data} view restricted to the permutation slice {@code [_start,_end)}.
 * Each entry of the permutation is an original row index of the DataAdapter.
 */
class Subset extends Data {
  private final int[] _permutation; // maps view positions to adapter row indices
  private final int _start, _end;   // half-open window into _permutation

  /** Creates new subset of the given data adapter. The permutation is an array
   * of original row indices of the DataAdapter object that will be used. */
  public Subset(Data data, int[] permutation, int start, int end) {
    super(data._dapt);
    _permutation = permutation;
    _start = start;
    _end = end;
  }

  @Override protected int start() { return _start; }
  @Override protected int end()   { return _end; }
  @Override protected int permute(int idx) { return _permutation[idx]; }
  @Override protected int[] getPermutationArray() { return _permutation; }

  /** Deep-copies the permutation so the clone may be reordered independently. */
  @Override public Subset clone() throws CloneNotSupportedException {
    return new Subset(this, _permutation.clone(), _start, _end);
  }

  /**
   * Builds the out-of-bag view: all rows whose draw count in {@code complement}
   * is zero, i.e. rows never selected by the bootstrap sample.
   */
  @Override public Data complement(Data parent, short[] complement) {
    int zeros = 0;
    for (int i = 0; i < complement.length; i++)
      if (complement[i] == 0) zeros++;
    int[] oob = MemoryManager.malloc4(zeros);
    int out = 0;
    for (int i = 0; i < complement.length; i++)
      if (complement[i] == 0) oob[out++] = i;
    return new Subset(this, oob, 0, oob.length);
  }
}