package water.rapids;
// General principle here: several parallel, tight, branch-free loops are
// faster than one heavy DKV pass per row.
// It is intended that several of these SingleThreadRadixOrder tasks run on the
// same node, to utilize the cores available. The initial MSB split needs at
// least num nodes * cpus per node buckets; e.g. 256 is pretty good for 10
// nodes of 32 cores. Later, use 9 bits, or a few more bits, accordingly.
// It's this 256 * 4kB = 1MB that needs to be < cache per core for cache-write
// efficiency in MoveByFirstByte(). 10 bits (1024 threads) would be 4MB, which
// is still < L2.
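// A quick worked check of that arithmetic (the 4kB-per-bucket write buffer is
// the figure assumed above, not re-measured here): 8 MSB bits -> 2^8 = 256
// buckets, and 256 buckets * 4kB of hot write buffer = 1MB per thread;
// 10 bits -> 2^10 = 1024 buckets * 4kB = 4MB per thread, still within a
// typical L2.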
// Since o[] and x[] are arrays here (not Vecs) it's harder to see how to
// parallelize inside this function. Therefore we avoid that issue by using
// more threads in the calling split.
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import java.util.Arrays;
class SingleThreadRadixOrder extends DTask<SingleThreadRadixOrder> {
private final Frame _fr;
private final int _MSBvalue; // only needed to be able to return the number of groups back to the caller RadixOrder
private final int _keySize, _batchSize;
private final boolean _isLeft;
private transient long _o[/*batch*/][];
private transient byte _x[/*batch*/][];
private transient long _otmp[][];
private transient byte _xtmp[][];
// TEMPs
private transient long counts[][];
private transient byte keytmp[];
//public long _groupSizes[][];
// outputs ...
// o and x are changed in-place always
// iff _groupsToo==true then the following are allocated and returned
SingleThreadRadixOrder(Frame fr, boolean isLeft, int batchSize, int keySize, /*long nGroup[],*/ int MSBvalue) {
_fr = fr;
_isLeft = isLeft;
_batchSize = batchSize;
_keySize = keySize;
_MSBvalue = MSBvalue;
}
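  // A minimal usage sketch, assuming a driver such as RadixOrder fans out one
  // task per MSB bucket. The loop below is illustrative only and kept as a
  // comment: the bucket count, batch size, and node-assignment policy are
  // assumptions, not part of this class's contract.
  //   Futures fs = new Futures();
  //   for (int msb = 0; msb < 256; msb++) {
  //     H2ONode target = H2O.CLOUD._memary[msb % H2O.CLOUD.size()];
  //     fs.add(new RPC<>(target,
  //         new SingleThreadRadixOrder(fr, /*isLeft=*/true, /*batchSize=*/1<<20, keySize, msb)).call());
  //   }
  //   fs.blockForPending();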
@Override
public void compute2() {
keytmp = new byte[_keySize];
counts = new long[_keySize][256];
Key k;
SplitByMSBLocal.MSBNodeHeader[] MSBnodeHeader = new SplitByMSBLocal.MSBNodeHeader[H2O.CLOUD.size()];
long numRows =0;
for (int n=0; n<H2O.CLOUD.size(); n++) {
// Log.info("Getting MSB " + MSBvalue + " Node Header from node " + n + "/" + H2O.CLOUD.size() + " for Frame " + _fr._key);
k = SplitByMSBLocal.getMSBNodeHeaderKey(_isLeft, _MSBvalue, n);
MSBnodeHeader[n] = DKV.getGet(k);
if (MSBnodeHeader[n]==null) continue;
DKV.remove(k);
numRows += ArrayUtils.sum(MSBnodeHeader[n]._MSBnodeChunkCounts); // This numRows is split into nbatch batches on that node.
// This header has the counts of each chunk (the ordered chunk numbers on that node)
}
if (numRows == 0) { tryComplete(); return; }
// Allocate final _o and _x for this MSB which is gathered together on this
// node from the other nodes.
// TO DO: as Arno suggested, wrap up into class for fixed width batching
// (to save espc overhead)
int nbatch = (int) ((numRows-1) / _batchSize +1); // at least one batch.
    // the size of the last batch (could be batchSize too, if numRows happens
    // to be an exact multiple of batchSize)
int lastSize = (int) (numRows - (nbatch-1)*_batchSize);
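    // Worked example of the two lines above (numbers are illustrative):
    // numRows = 2,500,000 with _batchSize = 1,000,000 gives
    // nbatch = (2,500,000-1)/1,000,000 + 1 = 3 and
    // lastSize = 2,500,000 - 2*1,000,000 = 500,000. If numRows were exactly
    // 2,000,000 then nbatch = 2 and lastSize = 1,000,000, a full batch.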
_o = new long[nbatch][];
_x = new byte[nbatch][];
int b;
for (b = 0; b < nbatch-1; b++) {
_o[b] = new long[_batchSize]; // TO DO?: use MemoryManager.malloc8()
_x[b] = new byte[_batchSize * _keySize];
}
_o[b] = new long[lastSize];
_x[b] = new byte[lastSize * _keySize];
SplitByMSBLocal.OXbatch ox[/*node*/] = new SplitByMSBLocal.OXbatch[H2O.CLOUD.size()];
int oxBatchNum[/*node*/] = new int[H2O.CLOUD.size()]; // which batch of OX are we on from that node? Initialized to 0.
    for (int node=0; node<H2O.CLOUD.size(); node++) { // TO DO: why is this serial?
k = SplitByMSBLocal.getNodeOXbatchKey(_isLeft, _MSBvalue, node, /*batch=*/0);
// assert k.home(); // TODO: PUBDEV-3074
ox[node] = DKV.getGet(k); // get the first batch for each node for this MSB
DKV.remove(k);
}
int oxOffset[] = new int[H2O.CLOUD.size()];
int oxChunkIdx[] = new int[H2O.CLOUD.size()]; // that node has n chunks and which of those are we currently on?
int targetBatch = 0, targetOffset = 0, targetBatchRemaining = _batchSize;
final Vec vec = _fr.anyVec();
assert vec != null;
for (int c=0; c<vec.nChunks(); c++) {
int fromNode = vec.chunkKey(c).home_node().index(); // each chunk in the column may be on different nodes
      // See the long comment at the top of SendSplitMSB. One line from there is repeated here:
// " When the helper node (i.e. this one, now) (i.e the node doing all
// the A's) gets the A's from that node, it must stack all the nodes' A's
// with the A's from the other nodes in chunk order in order to maintain
// the original order of the A's within the global table. "
      // TODO: We could process these in node order and/or in parallel if we
      // accumulated the counts first to know the offsets - should be doable
      // and high value
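      // A small worked example of that requirement (layout is illustrative):
      // suppose chunks 0 and 2 live on node A and chunk 1 on node B, and for
      // this MSB chunk 0 contributes rows {a0,a1}, chunk 1 contributes {b0},
      // and chunk 2 contributes {a2}. The gathered _o/_x must read
      // a0,a1,b0,a2 - the nodes' contributions interleaved in global chunk
      // order - not a0,a1,a2,b0 in node order.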
if (MSBnodeHeader[fromNode] == null) continue;
      // Magically this works, given the outer for loop through global chunks.
      // It relies on the chunk-order guarantee quoted from SendSplitMSB above.
int numRowsToCopy = MSBnodeHeader[fromNode]._MSBnodeChunkCounts[oxChunkIdx[fromNode]++];
      // _MSBnodeChunkCounts is a vector of the number of contributions from
      // each Vec chunk. Since each chunk's length fits in an int, each count
      // must be less than that too, so int is fine. The set of data
      // corresponding to the Vec chunk contributions is stored packed in the
      // batched vectors _o and _x.
// at most batchSize remaining. No need to actually put the number of rows left in here
int sourceBatchRemaining = _batchSize - oxOffset[fromNode];
      while (numRowsToCopy > 0) { // No need for a class here; this is a bit different from the batch copier in runCopy(), and two copies of the logic isn't too bad.
int thisCopy = Math.min(numRowsToCopy, Math.min(sourceBatchRemaining, targetBatchRemaining));
System.arraycopy(ox[fromNode]._o, oxOffset[fromNode], _o[targetBatch], targetOffset, thisCopy);
System.arraycopy(ox[fromNode]._x, oxOffset[fromNode]*_keySize, _x[targetBatch], targetOffset*_keySize, thisCopy*_keySize);
numRowsToCopy -= thisCopy;
oxOffset[fromNode] += thisCopy; sourceBatchRemaining -= thisCopy;
targetOffset += thisCopy; targetBatchRemaining -= thisCopy;
if (sourceBatchRemaining == 0) {
// fetch the next batch :
k = SplitByMSBLocal.getNodeOXbatchKey(_isLeft, _MSBvalue, fromNode, ++oxBatchNum[fromNode]);
assert k.home();
ox[fromNode] = DKV.getGet(k);
DKV.remove(k);
if (ox[fromNode] == null) {
            // if the last chunk's worth fills a batch exactly, the getGet above will have returned null.
            // TODO: Check with Cliff that a known fetch of a non-existent key is ok, e.g. won't cause a delay/block? If ok, leave this as a good check.
int numNonZero = 0; for (int tmp : MSBnodeHeader[fromNode]._MSBnodeChunkCounts) if (tmp>0) numNonZero++;
assert oxBatchNum[fromNode]==numNonZero;
assert ArrayUtils.sum(MSBnodeHeader[fromNode]._MSBnodeChunkCounts) % _batchSize == 0;
}
oxOffset[fromNode] = 0;
sourceBatchRemaining = _batchSize;
}
if (targetBatchRemaining == 0) {
targetBatch++;
targetOffset = 0;
targetBatchRemaining = _batchSize;
}
}
}
// We now have _o and _x collated from all the contributing nodes, in the correct original order.
    // TODO: save this allocation and reuse it per thread? Or will the heap just take care of it? Time this allocation and copy as step 1 anyway.
_xtmp = new byte[_x.length][];
_otmp = new long[_o.length][];
    assert _x.length == _o.length; // i.e. same number of batches in x and o (think 20-byte keys in x against 8-byte longs in o)
    // Java has no built-in deep clone for jagged arrays; copy batch by batch.
    // (System.arraycopy would work too, but the target must be allocated first.)
for (int i=0; i<_x.length; i++) {
_xtmp[i] = Arrays.copyOf(_x[i], _x[i].length);
_otmp[i] = Arrays.copyOf(_o[i], _o[i].length);
}
// TO DO: a way to share this working memory between threads.
// Just create enough for the 4 threads active at any one time. Not 256 allocations and releases.
// We need o[] and x[] in full for the result. But this way we don't need full size xtmp[] and otmp[] at any single time.
    // Currently Java will allocate and free these xtmp and otmp; maybe it does a good enough job reusing the heap that we don't need to optimize this reuse explicitly.
    // Perhaps iterating this task through the largest bins first would help Java reuse the heap.
assert(_o != null);
assert(numRows > 0);
// The main work. Radix sort this batch ...
run(0, numRows, _keySize-1); // if keySize is 6 bytes, first byte is byte 5
    // no need to clear counts and keytmp now that they're private transient
// tell the world how many batches and rows for this MSB
OXHeader msbh = new OXHeader(_o.length, numRows, _batchSize);
Futures fs = new Futures();
DKV.put(getSortedOXHeaderKey(_isLeft, _MSBvalue), msbh, fs, true);
assert _o.length == _x.length;
for (b=0; b<_o.length; b++) {
SplitByMSBLocal.OXbatch tmp = new SplitByMSBLocal.OXbatch(_o[b], _x[b]);
Value v = new Value(SplitByMSBLocal.getSortedOXbatchKey(_isLeft, _MSBvalue, b), tmp);
DKV.put(v._key, v, fs, true); // the OXbatchKey's on this node will be reused for the new keys
v.freeMem();
}
// TODO: check numRows is the total of the _x[b] lengths
fs.blockForPending();
tryComplete();
}
static Key getSortedOXHeaderKey(boolean isLeft, int MSBvalue) {
    // This key merges together data from all nodes, so its data is not "from"
    // any particular node. Therefore the node number should not be in the key.
    return Key.make("__radix_order__SortedOXHeader_MSB" + MSBvalue + (isLeft ? "_LEFT" : "_RIGHT"));
}
static class OXHeader extends Iced<OXHeader> {
OXHeader(int batches, long numRows, int batchSize) { _nBatch = batches; _numRows = numRows; _batchSize = batchSize; }
final int _nBatch;
final long _numRows;
final int _batchSize;
}
private int keycmp(byte x[], int xi, byte y[], int yi) {
// Same return value as strcmp in C. <0 => xi<yi
xi *= _keySize; yi *= _keySize;
int len = _keySize;
while (len > 1 && x[xi] == y[yi]) { xi++; yi++; len--; }
    return ((x[xi] & 0xFF) - (y[yi] & 0xFF)); // & 0xFF to compare the bytes as unsigned (0..255)
}
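  // A worked example of keycmp (values are illustrative): with _keySize == 2,
  // comparing key {0x01, 0x02} against {0x01, 0x80} skips the equal first
  // byte, then returns (0x02 & 0xFF) - (0x80 & 0xFF) = 2 - 128 = -126 < 0, so
  // the first key sorts before the second. Without the mask, 0x80 would read
  // as the signed byte -128 and the comparison would go the wrong way.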
  // Orders both x and o by reference in-place. Fast for small vectors, low
  // overhead. Don't be tempted to binary-search backwards here, because we
  // have to shift everything anyway.
public void insert(long start, /*only for small len so len can be type int*/int len) {
int batch0 = (int) (start / _batchSize);
int batch1 = (int) ((start+len-1) / _batchSize);
    long origstart = start; // just for when we straddle batch boundaries
int len0 = 0; // same
byte _xbatch[];
long _obatch[];
if (batch1 != batch0) {
      // small len straddles a batch boundary. Unlikely to happen often, since len < 200
assert batch0 == batch1-1;
len0 = _batchSize - (int)(start % _batchSize);
// copy two halves to contiguous temp memory, do the below, then split it back to the two halves afterwards.
// Straddles batches very rarely (at most once per batch) so no speed impact at all.
      _xbatch = new byte[len * _keySize];
      System.arraycopy(_x[batch0], (int)((start % _batchSize)*_keySize), _xbatch, 0, len0*_keySize);
      System.arraycopy(_x[batch1], 0, _xbatch, len0*_keySize, (len-len0)*_keySize);
      _obatch = new long[len];
      System.arraycopy(_o[batch0], (int)(start % _batchSize), _obatch, 0, len0);
      System.arraycopy(_o[batch1], 0, _obatch, len0, len-len0);
start = 0;
} else {
      _xbatch = _x[batch0]; // taking this outside the loop does indeed make quite a big difference (hotspot isn't catching this, then)
_obatch = _o[batch0];
}
int offset = (int) (start % _batchSize);
for (int i=1; i<len; i++) {
int cmp = keycmp(_xbatch, offset+i, _xbatch, offset+i-1); // TO DO: we don't need to compare the whole key here. Set cmpLen < keySize
if (cmp < 0) {
System.arraycopy(_xbatch, (offset+i)*_keySize, keytmp, 0, _keySize);
int j = i-1;
long otmp = _obatch[offset+i];
do {
System.arraycopy(_xbatch, (offset+j)*_keySize, _xbatch, (offset+j+1)*_keySize, _keySize);
_obatch[offset+j+1] = _obatch[offset+j];
j--;
} while (j >= 0 && keycmp(keytmp, 0, _xbatch, offset+j)<0);
System.arraycopy(keytmp, 0, _xbatch, (offset+j+1)*_keySize, _keySize);
_obatch[offset + j + 1] = otmp;
}
}
if (batch1 != batch0) {
// Put the sorted data back into original two places straddling the boundary
      System.arraycopy(_xbatch, 0, _x[batch0], (int)((origstart % _batchSize)*_keySize), len0*_keySize);
      System.arraycopy(_xbatch, len0*_keySize, _x[batch1], 0, (len-len0)*_keySize);
      System.arraycopy(_obatch, 0, _o[batch0], (int)(origstart % _batchSize), len0);
      System.arraycopy(_obatch, len0, _o[batch1], 0, len-len0);
}
}
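  // A tiny worked example of insert (values are illustrative): with
  // _keySize == 1, keys {03, 01, 02} and o = {10, 11, 12}, the pass sees
  // 01 < 03 and shifts 03 right, then sees 02 < 03 and shifts 03 right again,
  // leaving keys {01, 02, 03} and o = {11, 12, 10}. x and o are always
  // permuted together, so o still maps each key back to its original row.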
public void run(final long start, final long len, final int Byte) {
    if (len < 200) { // N_SMALL=200 is a guess based on limited testing. Needs calibrate().
      // Was 50, based on sum(1:50)=1275 worst case -vs- 256 cumulate + 256
      // memset + an allowance, since reverse order is unlikely.
      insert(start, (int)len); // when nalast==0, iinsert will be called only from within iradix.
// TO DO: inside insert it doesn't need to compare the bytes so far as
// they're known equal, so pass Byte (NB: not Byte-1) through to insert()
// TO DO: Maybe transposing keys to be a set of _keySize byte columns
// might in fact be quicker - no harm trying. What about long and varying
// length string keys?
return;
}
final int batch0 = (int) (start / _batchSize);
final int batch1 = (int) ((start+len-1) / _batchSize);
    // could well span more than one boundary when there is a very large number of rows.
final long thisHist[] = counts[Byte];
// thisHist reused and carefully set back to 0 below so we don't need to clear it now
int idx = (int)(start%_batchSize)*_keySize + _keySize-Byte-1;
int bin=-1; // the last bin incremented. Just to see if there is only one bin with a count.
int thisLen = (int)Math.min(len, _batchSize - start%_batchSize);
final int nbatch = batch1-batch0+1; // number of batches this span of len covers. Usually 1. Minimum 1.
for (int b=0; b<nbatch; b++) {
      // taking this outside the loop below does indeed make quite a big difference (hotspot isn't catching this, then)
byte _xbatch[] = _x[batch0+b];
for (int i = 0; i < thisLen; i++) {
bin = 0xff & _xbatch[idx];
thisHist[bin]++;
idx += _keySize;
        // maybe TO DO: shorten the key by 1 byte on each iteration, so we
        // only need thisx & 0xFF. No, because we need the full key for
        // construction of the final table's key columns.
}
idx = _keySize-Byte-1;
      // next iteration will be the last batch; use ((start+len-1)%_batchSize)+1
      // so that a span ending exactly on a batch boundary gives a full last
      // batch rather than 0
      thisLen = (b==nbatch-2 ? (int)((start+len-1)%_batchSize)+1 : _batchSize);
      // thisLen will be set to _batchSize for the middle batches when nbatch>=3
}
if (thisHist[bin] == len) {
// one bin has count len and the rest zero => next byte quick
thisHist[bin] = 0; // important, clear for reuse
if (Byte != 0)
run(start, len, Byte-1);
return;
}
long rollSum = 0;
for (int c = 0; c < 256; c++) {
final long tmp = thisHist[c];
      // important to skip the zeros, for the logic below that undoes the
      // cumulate. Worth the branch to save a deeply iterative memset back to zero
if (tmp == 0) continue;
thisHist[c] = rollSum;
rollSum += tmp;
}
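    // A worked example of this exclusive cumulate (counts are illustrative):
    // if thisHist arrives as {bin 0: 2, bin 3: 5, bin 7: 1} (all other bins
    // zero), the loop rewrites it to {bin 0: 0, bin 3: 2, bin 7: 7} with
    // rollSum == 8. Each non-zero bin now holds the target offset of its
    // first row for the scatter pass below; bins with zero count stay 0 so
    // the group loop at the end can tell real groups apart.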
// Sigh. Now deal with batches here as well because Java doesn't have 64bit indexing.
int oidx = (int)(start%_batchSize);
int xidx = oidx*_keySize + _keySize-Byte-1;
thisLen = (int)Math.min(len, _batchSize - start%_batchSize);
for (int b=0; b<nbatch; b++) {
      // taking these outside the loop below does indeed make quite a big
      // difference (hotspot isn't catching this, then)
final long _obatch[] = _o[batch0+b];
final byte _xbatch[] = _x[batch0+b];
for (int i = 0; i < thisLen; i++) {
long target = thisHist[0xff & _xbatch[xidx]]++;
// now always write to the beginning of _otmp and _xtmp just to reuse the first hot pages
        _otmp[(int)(target/_batchSize)][(int)(target%_batchSize)] = _obatch[oidx+i]; // this must be kept in 8-byte longs
System.arraycopy(_xbatch, (oidx+i)*_keySize, _xtmp[(int)(target/_batchSize)], (int)(target%_batchSize)*_keySize, _keySize );
xidx += _keySize;
        // Maybe TO DO: this could use a variable byte width, narrowing as we
        // descend through the bytes (TO DO: reverse the byte order so we're always doing & 0xFF)
}
xidx = _keySize-Byte-1;
oidx = 0;
      thisLen = (b==nbatch-2/*next iteration will be last batch*/ ? (int)((start+len-1)%_batchSize)+1 : _batchSize);
}
// now copy _otmp and _xtmp back over _o and _x from the start position, allowing for boundaries
// _o, _x, _otmp and _xtmp all have the same _batchsize
runCopy(start,len,_keySize,_batchSize,_otmp,_xtmp,_o,_x);
long itmp = 0;
for (int i=0; i<256; i++) {
if (thisHist[i]==0) continue;
final long thisgrpn = thisHist[i] - itmp;
if( !(thisgrpn == 1 || Byte == 0) )
run(start+itmp, thisgrpn, Byte-1);
itmp = thisHist[i];
thisHist[i] = 0; // important, to save clearing counts on next iteration
}
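    // A small worked example of the recursion above (numbers illustrative):
    // with _keySize == 3 the first call is run(0, numRows, 2). If, after the
    // scatter, bin 0x41 ends at offset 5 and bin 0x42 at offset 6, then rows
    // [0,5) form a 5-row group that recurses as run(0, 5, 1) on the next
    // byte, while the 1-row group [5,6) is already in place and is skipped
    // (thisgrpn == 1), as is every group once Byte == 0.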
}
// Hot loop, pulled out from the main run code
private static void runCopy(final long start, final long len, final int keySize, final int batchSize, final long otmp[][], final byte xtmp[][], final long o[][], final byte x[][]) {
// now copy _otmp and _xtmp back over _o and _x from the start position, allowing for boundaries
// _o, _x, _otmp and _xtmp all have the same _batchsize
// Would be really nice if Java had 64bit indexing to save programmer time.
long numRowsToCopy = len;
int sourceBatch = 0, sourceOffset = 0;
int targetBatch = (int)(start / batchSize), targetOffset = (int)(start % batchSize);
    int targetBatchRemaining = batchSize - targetOffset; // 'remaining' means of the full batch, not of numRowsToCopy
int sourceBatchRemaining = batchSize - sourceOffset; // at most batchSize remaining. No need to actually put the number of rows left in here
while (numRowsToCopy > 0) { // TO DO: put this into class as well, to ArrayCopy into batched
final int thisCopy = (int)Math.min(numRowsToCopy, Math.min(sourceBatchRemaining, targetBatchRemaining));
System.arraycopy(otmp[sourceBatch], sourceOffset, o[targetBatch], targetOffset, thisCopy);
System.arraycopy(xtmp[sourceBatch], sourceOffset*keySize, x[targetBatch], targetOffset*keySize, thisCopy*keySize);
numRowsToCopy -= thisCopy;
sourceOffset += thisCopy; sourceBatchRemaining -= thisCopy;
targetOffset += thisCopy; targetBatchRemaining -= thisCopy;
if (sourceBatchRemaining == 0) { sourceBatch++; sourceOffset = 0; sourceBatchRemaining = batchSize; }
if (targetBatchRemaining == 0) { targetBatch++; targetOffset = 0; targetBatchRemaining = batchSize; }
      // 'source' and 'target' deliberately have same-length variable names, and the long lines are
      // deliberate too, so we can easily match them up vertically to ensure they are the same
}
}
}