package water.rapids;

import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import java.util.ArrayList;
import static water.rapids.SingleThreadRadixOrder.getSortedOXHeaderKey;

public class Merge {

  // Radix-sort a Frame using the given columns as keys.
  // This is a fully distributed and parallel sort.
  // It is not currently an in-place sort, so the data is doubled and a sorted copy is returned.
  public static Frame sort( final Frame fr, int[] cols ) {
    if( cols.length==0 )        // Empty key list
      return fr;                // Return original frame
    // Validate that every requested key column is in range of the frame
    for( int col : cols )
      if( col < 0 || col >= fr.numCols() )
        throw new IllegalArgumentException("Column "+col+" is out of range of "+fr.numCols());
    // All identity ID maps: for each categorical key column, map every level
    // index to itself.  merge() uses these maps to line up left/right levels;
    // for a pure sort there is no right frame, so identity is the right map.
    int id_maps[][] = new int[cols.length][];
    for( int i=0; i<cols.length; i++ ) {
      Vec vec = fr.vec(cols[i]);
      if( vec.isCategorical() ) {
        String[] domain = vec.domain();
        id_maps[i] = new int[domain.length];
        for( int j=0; j<domain.length; j++ ) id_maps[i][j] = j;
      }
    }
    // A sort is expressed as a merge against an empty right frame with allLeft=true,
    // so every left row is kept, in key order.
    return Merge.merge(fr, new Frame(new Vec[0]), cols, new int[0], true/*allLeft*/, id_maps);
  }

  // single-threaded driver logic.  Merge left and right frames based on common columns.
  //
  // Overview of the algorithm as visible below:
  //   1. Radix-index both sides (createIndex), which distributes rows into 256
  //      MSB (most-significant-byte) bins per side.
  //   2. Fire one BinaryMerge RPC per (leftMSB, rightMSB) bin pair that can
  //      overlap in key space; left-only bins are fetched with rightMSB == -1
  //      when allLeft is true.
  //   3. Remove the temporary index keys from the DKV.
  //   4. Stitch the per-bin result chunks into a single output Frame.
  //
  // Parameters:
  //   leftFrame/riteFrame - the two frames to merge (riteFrame may be empty for sort)
  //   leftCols/riteCols   - key column indices on each side; riteCols.length==0 means "no right side"
  //   allLeft             - keep unmatched left rows (left outer join semantics)
  //   id_maps             - per key column: left categorical level -> right level, or null for non-categoricals
  // Returns: a new Frame holding the merged (or sorted) result; inputs are not modified.
  public static Frame merge(final Frame leftFrame, final Frame riteFrame, final int leftCols[], final int riteCols[], boolean allLeft, int[][] id_maps) {
    final boolean hasRite = riteCols.length > 0;

    // map missing levels to -1 (rather than increasing slots after the end)
    // for now to save a deep branch later
    for (int i=0; i<id_maps.length; i++) {
      if (id_maps[i] == null) continue;
      assert id_maps[i].length >= leftFrame.vec(leftCols[i]).max()+1;
      if( !hasRite ) continue;
      int right_max = (int)riteFrame.vec(riteCols[i]).max();
      for (int j=0; j<id_maps[i].length; j++) {
        assert id_maps[i][j] >= 0;
        if (id_maps[i][j] > right_max) id_maps[i][j] = -1;
      }
    }

    // Running 3 consecutive times on an idle cluster showed that running left
    // and right in parallel was a little slower (97s) than one by one (89s).
    // TODO: retest in future
    RadixOrder leftIndex = createIndex(true ,leftFrame,leftCols,id_maps);
    RadixOrder riteIndex = createIndex(false,riteFrame,riteCols,id_maps);

    // TODO: start merging before all indexes had been created. Use callback?

    System.out.print("Making BinaryMerge RPC calls ... ");
    long t0 = System.nanoTime();
    ArrayList<BinaryMerge> bmList = new ArrayList<>();
    Futures fs = new Futures();
    final int leftShift = leftIndex._shift[0];
    final long leftBase = leftIndex._base[0];
    // NOTE(review): when !hasRite, riteShift==-1 makes (256L<<riteShift) below
    // shift by 63 (Java masks long shift distances to 6 bits) and overflow to 0,
    // which steers the range arithmetic into the "left rows only" paths.  This
    // appears intentional but is subtle -- confirm before changing.
    final int riteShift = hasRite ? riteIndex._shift[0] : -1;
    final long riteBase = hasRite ? riteIndex._base [0] : leftBase;

    long leftMSBfrom = (riteBase - leftBase) >> leftShift;  // which leftMSB does the overlap start

    // deal with the left range below the right minimum, if any
    if (leftBase < riteBase) {
      // deal with the range of the left below the start of the right, if any
      assert leftMSBfrom >= 0;
      if (leftMSBfrom>255) {
        // The left range ends before the right range starts. So every left row is a no-match to the right
        leftMSBfrom = 256;  // so that the loop below runs for all MSBs (0-255) to fetch the left rows only
      }
      // run the merge for the whole lefts that end before the first right.
      // The overlapping one with the right base is dealt with inside
      // BinaryMerge (if _allLeft)
      if (allLeft)
        for (int leftMSB=0; leftMSB<leftMSBfrom; leftMSB++) {
          // rightMSB == -1 means "no right bin"; BinaryMerge then just fetches the left rows
          BinaryMerge bm = new BinaryMerge(new BinaryMerge.FFSB(leftFrame, leftMSB ,leftShift,leftIndex._bytesUsed,leftIndex._base),
                                           new BinaryMerge.FFSB(riteFrame,/*rightMSB*/-1,riteShift,riteIndex._bytesUsed,riteIndex._base),
                                           true);
          bmList.add(bm);
          // execute each bin's merge on the node that owns that MSB's index data
          fs.add(new RPC<>(SplitByMSBLocal.ownerOfMSB(leftMSB), bm).call());
        }
    } else {
      // completely ignore right MSBs below the left base
      assert leftMSBfrom <= 0;
      leftMSBfrom = 0;
    }

    long leftMSBto = (riteBase + (256L<<riteShift) - 1 - leftBase) >> leftShift;
    // -1 because the 256L<<riteShift is one after the max extent.
    // No need for +1 for NA here because, as for leftMSBfrom above, the NA spot is on both sides

    // deal with the left range above the right maximum, if any
    if( (leftBase + (256L<<leftShift)) > (riteBase + (256L<<riteShift)) ) {
      assert leftMSBto <= 255;
      if (leftMSBto<0) {
        // The left range starts after the right range ends. So every left row is a no-match to the right
        leftMSBto = -1;  // all MSBs (0-255) need to fetch the left rows only
      }
      // run the merge for the whole lefts that start after the last right
      if (allLeft)
        for (int leftMSB=(int)leftMSBto+1; leftMSB<=255; leftMSB++) {
          BinaryMerge bm = new BinaryMerge(new BinaryMerge.FFSB(leftFrame, leftMSB ,leftShift,leftIndex._bytesUsed,leftIndex._base),
                                           new BinaryMerge.FFSB(riteFrame,/*rightMSB*/-1,riteShift,riteIndex._bytesUsed,riteIndex._base),
                                           true);
          bmList.add(bm);
          fs.add(new RPC<>(SplitByMSBLocal.ownerOfMSB(leftMSB), bm).call());
        }
    } else {
      // completely ignore right MSBs after the right peak
      assert leftMSBto >= 255;
      leftMSBto = 255;
    }

    // the overlapped region; i.e. between [ max(leftMin,rightMin), min(leftMax, rightMax) ]
    for (int leftMSB=(int)leftMSBfrom; leftMSB<=leftMSBto; leftMSB++) {
      assert leftMSB >= 0;
      assert leftMSB <= 255;

      // calculate the key values at the bin extents:  [leftFrom,leftTo] in terms of keys
      long leftFrom= (((long)leftMSB  ) << leftShift) -1 + leftBase;    // -1 for leading NA spot
      long leftTo  = (((long)leftMSB+1) << leftShift) -1 + leftBase-1;  // -1 for leading NA spot and another -1 to get last of previous bin

      // which right bins do these left extents occur in (could span multiple, and fall in the middle)
      int rightMSBfrom = (int)((leftFrom - riteBase + 1) >> riteShift);  // +1 again for the leading NA spot
      int rightMSBto   = (int)((leftTo   - riteBase + 1) >> riteShift);

      // the non-matching part of this region will have been dealt with above when allLeft==true
      if (rightMSBfrom < 0) rightMSBfrom = 0;
      assert rightMSBfrom <= 255;
      if (rightMSBto > 255) rightMSBto = 255;
      assert rightMSBto >= rightMSBfrom;

      for (int rightMSB=rightMSBfrom; rightMSB<=rightMSBto; rightMSB++) {
        BinaryMerge bm = new BinaryMerge(new BinaryMerge.FFSB(leftFrame, leftMSB,leftShift,leftIndex._bytesUsed,leftIndex._base),
                                         new BinaryMerge.FFSB(riteFrame,rightMSB,riteShift,riteIndex._bytesUsed,riteIndex._base),
                                         allLeft);
        bmList.add(bm);
        // TODO: choose the bigger side to execute on (where that side of index
        // already is) to minimize transfer.  within BinaryMerge it will
        // recalculate the extents in terms of keys and bsearch for them within
        // the (then local) both sides
        H2ONode node = SplitByMSBLocal.ownerOfMSB(rightMSB);
        fs.add(new RPC<>(node, bm).call());
      }
    }
    System.out.println("took: " + String.format("%.3f", (System.nanoTime() - t0) / 1e9));

    t0 = System.nanoTime();
    System.out.println("Sending BinaryMerge async RPC calls in a queue ... ");
    fs.blockForPending();  // wait for every per-bin BinaryMerge RPC to finish
    System.out.println("took: " + (System.nanoTime() - t0) / 1e9);

    System.out.print("Removing DKV keys of left and right index. ... ");
    // TODO: In future we won't delete but rather persist them as index on the table
    // Explicitly deleting here (rather than Arno's cleanUp) to reveal if we're not removing keys early enough elsewhere
    t0 = System.nanoTime();
    for (int msb=0; msb<256; msb++) {
      for (int isLeft=0; isLeft<2; isLeft++) {
        // remove the per-MSB sorted-order header, then every batch key it describes
        Key k = getSortedOXHeaderKey(isLeft!=0, msb);
        SingleThreadRadixOrder.OXHeader oxheader = DKV.getGet(k);
        DKV.remove(k);
        if (oxheader != null) {
          for (int b=0; b<oxheader._nBatch; ++b) {
            k = SplitByMSBLocal.getSortedOXbatchKey(isLeft!=0, msb, b);
            DKV.remove(k);
          }
        }
      }
    }
    System.out.println("took: " + (System.nanoTime() - t0)/1e9);

    System.out.print("Allocating and populating chunk info (e.g. size and batch number) ...");
    t0 = System.nanoTime();
    // Count result rows and chunks across all bin merges that produced output
    long ansN = 0;
    int numChunks = 0;
    for( BinaryMerge thisbm : bmList )
      if( thisbm._numRowsInResult > 0 ) {
        numChunks += thisbm._chunkSizes.length;
        ansN      += thisbm._numRowsInResult;
      }
    long chunkSizes[]   = new long[numChunks];
    int chunkLeftMSB[]  = new int[numChunks];  // using too much space repeating the same value here, but, limited
    int chunkRightMSB[] = new int[numChunks];
    int chunkBatch[]    = new int[numChunks];
    int k = 0;
    // Record, per output chunk, which (leftMSB,rightMSB,batch) produced it so
    // ChunkStitcher can find the corresponding raw chunk keys in the DKV.
    for( BinaryMerge thisbm : bmList ) {
      if (thisbm._numRowsInResult == 0) continue;
      int thisChunkSizes[] = thisbm._chunkSizes;
      for (int j=0; j<thisChunkSizes.length; j++) {
        chunkSizes[k] = thisChunkSizes[j];
        chunkLeftMSB [k] = thisbm._leftSB._msb;
        chunkRightMSB[k] = thisbm._riteSB._msb;
        chunkBatch[k] = j;
        k++;
      }
    }
    System.out.println("took: " + (System.nanoTime() - t0) / 1e9);

    // Now we can stitch together the final frame from the raw chunks that were put into the store
    System.out.print("Allocating and populated espc ...");
    t0 = System.nanoTime();
    // espc = cumulative row-start offsets per chunk (standard Vec layout)
    long espc[] = new long[chunkSizes.length+1];
    int i=0;
    long sum=0;
    for (long s : chunkSizes) {
      espc[i++] = sum;
      sum+=s;
    }
    espc[espc.length-1] = sum;
    System.out.println("took: " + (System.nanoTime() - t0) / 1e9);
    assert(sum==ansN);

    System.out.print("Allocating dummy vecs/chunks of the final frame ...");
    t0 = System.nanoTime();
    // Result columns = all left columns + right columns minus the join keys
    // (the right key columns are dropped; left ones carry the key values).
    int numJoinCols = hasRite ? leftIndex._bytesUsed.length : 0;
    int numLeftCols = leftFrame.numCols();
    int numColsInResult = numLeftCols + riteFrame.numCols() - numJoinCols ;
    final byte[] types = new byte[numColsInResult];
    final String[][] doms = new String[numColsInResult][];
    final String[] names = new String[numColsInResult];
    for (int j=0; j<numLeftCols; j++) {
      types[j] = leftFrame.vec(j).get_type();
      doms[j]  = leftFrame.domains()[j];
      names[j] = leftFrame.names()[j];
    }
    // NOTE(review): this assumes the right frame's join-key columns occupy its
    // first numJoinCols positions -- confirm against how riteCols is produced.
    for (int j=0; j<riteFrame.numCols()-numJoinCols; j++) {
      types[numLeftCols + j] = riteFrame.vec(j+numJoinCols).get_type();
      doms[numLeftCols + j]  = riteFrame.domains()[j+numJoinCols];
      names[numLeftCols + j] = riteFrame.names()[j+numJoinCols];
    }
    Key<Vec> key = Vec.newKey();
    // Zero-filled placeholder vecs with the final row layout; real data is copied in next
    Vec[] vecs = new Vec(key, Vec.ESPC.rowLayout(key, espc)).makeCons(numColsInResult, 0, doms, types);
    System.out.println("took: " + (System.nanoTime() - t0) / 1e9);

    System.out.print("Finally stitch together by overwriting dummies ...");
    t0 = System.nanoTime();
    Frame fr = new Frame(names, vecs);
    ChunkStitcher ff = new ChunkStitcher(chunkSizes, chunkLeftMSB, chunkRightMSB, chunkBatch);
    ff.doAll(fr);
    System.out.println("took: " + (System.nanoTime() - t0) / 1e9);

    //Merge.cleanUp();
    return fr;
  }

  // Build and return the radix-order index for one side.  Blocks until the
  // RadixOrder task (which launches MRTasks internally) completes.
  private static RadixOrder createIndex(boolean isLeft, Frame fr, int[] cols, int[][] id_maps) {
    System.out.println("\nCreating "+(isLeft ? "left" : "right")+" index ...");
    long t0 = System.nanoTime();
    RadixOrder idxTask = new RadixOrder(fr, isLeft, cols, id_maps);
    H2O.submitTask(idxTask);  // each of those launches an MRTask
    idxTask.join();
    System.out.println("***\n*** Creating "+(isLeft ? "left" : "right")+" index took: " + (System.nanoTime() - t0) / 1e9 + "\n***\n");
    return idxTask;
  }

  // MRTask that overwrites the dummy chunks of the result frame with the raw
  // result chunks that BinaryMerge stored in the DKV, keyed by
  // (leftMSB, rightMSB, column, batch).  Removes each raw chunk key after copying.
  static class ChunkStitcher extends MRTask<ChunkStitcher> {
    final long _chunkSizes[];    // expected row count per output chunk (for the assert below)
    final int _chunkLeftMSB[];   // per output chunk: producing left MSB
    final int _chunkRightMSB[];  // per output chunk: producing right MSB
    final int _chunkBatch[];     // per output chunk: batch number within that MSB pair
    ChunkStitcher(long[] chunkSizes,
                  int[] chunkLeftMSB,
                  int[] chunkRightMSB,
                  int[] chunkBatch
                  ) {
      _chunkSizes   = chunkSizes;
      _chunkLeftMSB = chunkLeftMSB;
      _chunkRightMSB= chunkRightMSB;
      _chunkBatch   = chunkBatch;
    }
    @Override public void map(Chunk[] cs) {
      int chkIdx = cs[0].cidx();
      Futures fs = new Futures();
      for (int i=0;i<cs.length;++i) {
        // Destination: this chunk's slot in the result vec.  Source: the raw
        // chunk BinaryMerge stored for this (MSB pair, column i, batch).
        Key destKey = cs[i].vec().chunkKey(chkIdx);
        assert(cs[i].len() == _chunkSizes[chkIdx]);
        Key k = BinaryMerge.getKeyForMSBComboPerCol(_chunkLeftMSB[chkIdx], _chunkRightMSB[chkIdx], i, _chunkBatch[chkIdx]);
        Chunk ck = DKV.getGet(k);
        DKV.put(destKey, ck, fs, /*don't cache*/true);
        DKV.remove(k);  // raw chunk no longer needed once copied into place
      }
      fs.blockForPending();
    }
  }
}