package water.rapids.ast.params;
import water.H2O;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstParameter;
import water.util.ArrayUtils;
import water.util.SB;
import java.util.ArrayList;
import java.util.Arrays;
/**
 * A collection of base/stride/cnts.
 * Syntax: { {num | num:cnt | num:cnt:stride},* }
 * <p/>
 * The bases can be unordered with dups (often used for column selection where
 * repeated columns are allowed, and order matters). The _isList flag tracks
 * that all cnts are 1 (and hence all strides are ignored and 1); these lists
 * may or may not be sorted. Note that some column selection is dense
 * (typical all-columns is: {0:MAX_INT}), and this has cnt>1.
 * <p/>
 * When cnts are > 1, bases must be sorted, with the last value of each range
 * (base+stride*(cnt-1)) always less than the next base. Typical use-case might
 * be a list of probabilities for computing quantiles, or grid-search parameters.
 * <p/>
 * Asking for a sorted integer expansion will sort the bases internally. The
 * has(), index(), min() and max() calls require a sorted list.
 */
public class AstNumList extends AstParameter {
  public final double[] _bases;
  final double[] _strides;
  final long[] _cnts;
  public final boolean _isList; // True if an unordered list of numbers (cnts are 1, stride is ignored)
  public boolean _isSort; // True if bases are sorted. May get updated later.

  /**
   * Builds a number list from parallel base / stride / count lists.
   *
   * @param bases   first value of each range
   * @param strides step between consecutive values of each range
   * @param counts  number of values in each range
   * @throws IllegalArgumentException if some range has cnt != 1 and two
   *         consecutive ranges overlap or are out of order
   */
  public AstNumList(ArrayList<Double> bases, ArrayList<Double> strides, ArrayList<Long> counts) {
    int n = bases.size();
    // Convert to fixed-sized arrays
    _bases = new double[n];
    _strides = new double[n];
    _cnts = new long[n];
    boolean isList = true;
    for (int i = 0; i < n; i++) {
      _bases[i] = bases.get(i);
      _cnts[i] = counts.get(i);
      _strides[i] = strides.get(i);
      if (_cnts[i] != 1) isList = false;
    }
    _isList = isList;
    // Complain about unordered bases, unless it's a simple number list
    boolean isSorted = true;
    for (int i = 1; i < n; i++)
      // Last value of the previous range must be strictly below the next base.
      if (_bases[i - 1] + (_cnts[i - 1] - 1) * _strides[i - 1] >= _bases[i]) {
        if (_isList) isSorted = false;
        else throw new IllegalArgumentException("Overlapping numeric ranges");
      }
    _isSort = isSorted;
  }

  /** A simple AstNumList of 1 number. */
  public AstNumList(double d) {
    _bases = new double[]{d};
    _strides = new double[]{1};
    _cnts = new long[]{1};
    _isList = _isSort = true;
  }

  /** A simple dense range AstNumList covering [lo, hi_exclusive) with stride 1. */
  public AstNumList(long lo, long hi_exclusive) {
    _bases = new double[]{lo};
    _strides = new double[]{1};
    _cnts = new long[]{hi_exclusive - lo};
    _isList = false;
    _isSort = true;
  }

  /** An empty number list. */
  public AstNumList() {
    _bases = new double[0];
    _strides = new double[0];
    _cnts = new long[0];
    _isList = _isSort = true;
  }

  /**
   * A plain list of numbers (every cnt is 1, every stride is 1).
   * <p/>
   * NOTE(review): {@code list} is stored without copying, and _isSort is left
   * false, so a later sort() reorders the caller's array in place.
   */
  public AstNumList(double[] list) {
    _bases = list;
    _strides = new double[list.length];
    _cnts = new long[list.length];
    _isList = true;
    Arrays.fill(_strides, 1);
    Arrays.fill(_cnts, 1);
  }

  /** A plain list of numbers built from an int array. */
  public AstNumList(int[] list) {
    this(ArrayUtils.copyFromIntArray(list));
  }

  // This is a special syntatic form; the number-list never executes and hits
  // the execution stack
  @Override
  public Val exec(Env env) {
    throw new IllegalArgumentException("Number list not allowed here");
  }

  /**
   * Debug rendering as {@code [base, base:end, base:end:stride, ...]}, where
   * end is the exclusive upper bound base+cnt*stride. The stride is printed
   * only when it is not 1 or the base is not integral.
   */
  @Override
  public String str() {
    SB sb = new SB().p('[');
    for (int i = 0; i < _bases.length; i++) {
      sb.p(_bases[i]);
      if (_cnts[i] != 1) {
        sb.p(':').p(_bases[i] + _cnts[i] * _strides[i]);
        if (_strides[i] != 1 || ((long) _bases[i]) != _bases[i])
          sb.p(':').p(_strides[i]);
      }
      if (i < _bases.length - 1) sb.p(',');
    }
    return sb.p(']').toString();
  }

  /**
   * Renders the fully expanded values as a Java array initializer,
   * e.g. {@code {1.0,2.0,3.0}}, or {@code "null"} (quoted) when empty.
   */
  @Override
  public String toJavaString() {
    double[] ary = expand();
    if (ary == null || ary.length == 0) return "\"null\"";
    SB sb = new SB().p('{');
    for (int i = 0; i < ary.length; ++i) {
      if (i > 0) sb.p(',');
      sb.p(ary[i]);
    }
    return sb.p('}').toString();
  }

  /**
   * Expands the compressed form into an array of doubles, in stored order.
   * <p/>
   * Each range contributes exactly cnt values {@code base + j*stride}. A NaN
   * base therefore contributes cnt NaNs.
   */
  public double[] expand() {
    double[] vals = new double[(int) cnt()];
    int r = 0;
    for (int i = 0; i < _bases.length; i++) {
      final double base = _bases[i], stride = _strides[i];
      // Index by count instead of accumulating d += stride: repeated
      // floating-point addition drifts (ten steps of 0.1 reach
      // 0.9999999999999999 < 1.0), which could emit cnt+1 values and overrun
      // vals, or stop one value short.
      for (long j = 0; j < _cnts[i]; j++)
        vals[r++] = base + j * stride;
    }
    return vals;
  }

  /**
   * Update-in-place sort of the ranges by base. This is a no-op when already
   * sorted; otherwise it permutes _bases, _strides and _cnts together.
   *
   * @return this list, now sorted
   */
  public AstNumList sort() {
    if (_isSort) return this; // Flow coding fast-path cutout
    int[] idxs = ArrayUtils.seq(0, _bases.length);
    ArrayUtils.sort(idxs, _bases);
    double[] bases = _bases.clone();
    double[] strides = _strides.clone();
    long[] cnts = _cnts.clone();
    for (int i = 0; i < idxs.length; i++) {
      _bases[i] = bases[idxs[i]];
      _strides[i] = strides[idxs[i]];
      _cnts[i] = cnts[idxs[i]];
    }
    _isSort = true;
    return this;
  }

  /**
   * Expands the compressed form into an array of ints, in stored order;
   * often used for unordered column lists. Each value is truncated with an
   * (int) cast.
   */
  public int[] expand4() {
    int[] vals = new int[(int) cnt()];
    int r = 0;
    for (int i = 0; i < _bases.length; i++) {
      final double base = _bases[i], stride = _strides[i];
      // Count-driven loop: exactly cnt values per range (see expand()).
      for (long j = 0; j < _cnts[i]; j++)
        vals[r++] = (int) (base + j * stride);
    }
    return vals;
  }

  /** Sorts this list in place, then expands it into ints; often used for sorted column lists. */
  int[] expand4Sort() {
    return sort().expand4();
  }

  /**
   * Expands the compressed form into an array of longs, in stored order;
   * often used for unordered row lists. Each value is truncated with a
   * (long) cast.
   */
  public long[] expand8() {
    long[] vals = new long[(int) cnt()];
    int r = 0;
    for (int i = 0; i < _bases.length; i++) {
      final double base = _bases[i], stride = _strides[i];
      // Count-driven loop: exactly cnt values per range (see expand()).
      for (long j = 0; j < _cnts[i]; j++)
        vals[r++] = (long) (base + j * stride);
    }
    return vals;
  }

  /** Sorts this list in place, then expands it into longs; often used for sorted row lists. */
  public long[] expand8Sort() {
    return sort().expand8();
  }

  /**
   * Returns the largest exclusive value: base+cnt*stride of the last range.
   * Requires a sorted list.
   */
  public double max() {
    assert _isSort;
    return _bases[_bases.length - 1] + _cnts[_cnts.length - 1] * _strides[_strides.length - 1];
  }

  /** Returns the smallest value (the first base). Requires a sorted list. */
  public double min() {
    assert _isSort;
    return _bases[0];
  }

  /** Returns the total number of values across all ranges. */
  public long cnt() {
    return ArrayUtils.sum(_cnts);
  }

  /** Returns true for a single stride-1 range starting at 0. */
  public boolean isDense() {
    return _cnts.length == 1 && _bases[0] == 0 && _strides[0] == 1;
  }

  /** Returns true when the list has no ranges. */
  public boolean isEmpty() {
    return _bases.length == 0;
  }

  /**
   * Checks whether {@code v} is in this list of numbers. Requires a sorted
   * list. Negative values are not supported (findBase throws).
   * NB: all contiguous ranges have already been checked to have stride 1.
   */
  public boolean has(long v) {
    int idx = findBase(v);
    if (idx >= 0) return true;
    // Arrays.binarySearch returns (-insertionPoint - 1); convert to the index
    // of the range whose base is just below v (-1 if v precedes every base).
    idx = -idx - 2;
    if (idx < 0) return false;
    assert _bases[idx] < v; // Sanity check binary search, AND idx >= 0
    return v < _bases[idx] + _cnts[idx] * _strides[idx] && (v - _bases[idx]) % _strides[idx] == 0;
  }

  /**
   * Finds the index of a given value in this number sequence, with indexing
   * starting at 0. Requires a sorted list.
   *
   * @param v value
   * @return value index (>= 0), or -1 if value is not a member of this sequence
   */
  public long index(long v) {
    int bIdx = findBase(v);
    // v is exactly the base of range bIdx: its index is the number of values
    // in all preceding ranges (same prefix sum as the offset case below).
    if (bIdx >= 0) return ArrayUtils.sum(_cnts, 0, bIdx);
    bIdx = -bIdx - 2;
    if (bIdx < 0) return -1L;
    assert _bases[bIdx] < v;
    long offset = v - (long) _bases[bIdx];
    long stride = (long) _strides[bIdx];
    if ((offset >= _cnts[bIdx] * stride) || (offset % stride != 0)) return -1L;
    return ArrayUtils.sum(_cnts, 0, bIdx) + (offset / stride);
  }

  /**
   * Binary-searches the bases for {@code v}.
   *
   * @return the result of {@link Arrays#binarySearch(double[], double)}
   */
  private int findBase(long v) {
    assert _isSort; // Only called when already sorted
    // do something special for negative indexing... that does not involve
    // allocating arrays, once per list element!
    if (v < 0) throw H2O.unimpl();
    return Arrays.binarySearch(_bases, v);
  }

  // Select columns by number. Numbers are capped to the number of columns +1
  // - this allows R to see a single out-of-range value and throw a range check
  // - this allows Python to see a single out-of-range value and ignore it
  // - this allows Python to pass [0:MAXINT] without blowing out the max number of columns.
  // Note that the Python front-end does not want to cap the max column size, because
  // this will force eager evaluation on a standard column slice operation.
  // Note that the list is often unsorted (_isSort is false).
  // Note that the list is often dense with cnts>1 (_isList is false).
  // NOTE(review): strides are not consulted here; every range is treated as
  // stride 1. Confirm that the parser guarantees stride 1 for column ranges.
  @Override
  public int[] columns(String[] names) {
    // Count total values, capped by max len+1
    int nrows = 0, r = 0;
    for (int i = 0; i < _bases.length; i++)
      nrows += Math.min(_bases[i] + _cnts[i], names.length + 1) - Math.min(_bases[i], names.length + 1);
    // Fill in values
    int[] vals = new int[nrows];
    for (int i = 0; i < _bases.length; i++) {
      // (int) of a double saturates at Integer.MAX_VALUE, so {0:MAX_INT} is safe.
      int lim = Math.min((int) (_bases[i] + _cnts[i]), names.length + 1);
      for (int d = Math.min((int) _bases[i], names.length + 1); d < lim; d++)
        vals[r++] = d;
    }
    return vals;
  }
}