package water.rapids;

import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.*;
import water.nbhm.NonBlockingHashMapLong;
import water.rapids.vals.ValFrame;

import java.io.IOException;
import java.util.Random;

import static org.junit.Assert.assertTrue;

public class SortTest extends TestUtil {
  @BeforeClass public static void setup() { stall_till_cloudsize(1); }

  @Test public void testBasicSortRapids() {
    Frame fr = null, res = null;

    // Stable sort on columns 1 and 2
    String tree = "(sort hex [1 2])";
    try {
      // Build a frame which is unsorted on small-count categoricals in columns
      // 0 and 1, and completely sorted on a record-number based column 2.
      // Sort will be on the original columns 0 and 1, in that order, and is
      // expected to be stable.  After the row column is inserted at position 0
      // below, those become columns 1 and 2 - hence "[1 2]" in the Rapids tree.
      fr = buildFrame(1000,10);
      fr.insertVec(0,"row",fr.remove(2));
      Val val = Rapids.exec(tree);
      assertTrue( val instanceof ValFrame );
      res = val.getFrame();
      res.add("row",res.remove(0));
      new CheckSort().doAll(res);
    } finally {
      if( fr  != null ) fr .delete();
      if( res != null ) res.delete();
    }
  }

  @Test public void testBasicSortJava() {
    Frame fr = null, res = null;
    try {
      fr = buildFrame(1000,10);
      fr.insertVec(0,"row",fr.remove(2));
      res = Merge.sort(fr,new int[]{1,2});
      res.add("row",res.remove(0));
      new CheckSort().doAll(res);
    } finally {
      if( fr  != null ) fr .delete();
      if( res != null ) res.delete();
    }
  }

  @Test public void testBasicSortJava2() {
    Frame fr = null, res = null;
    try {
      fr = buildFrame(1000,10);
      String[] domain = new String[1000];
      for( int i=0; i<1000; i++ ) domain[i] = "D"+i;
      fr.vec(0).setDomain(domain);
      res = fr.sort(new int[]{0,1});
      new CheckSort().doAll(res);
    } finally {
      if( fr  != null ) fr .delete();
      if( res != null ) res.delete();
    }
  }

  // Assert that the result is indeed sorted - on all 3 columns, as this is a
  // stable sort.
  private class CheckSort extends MRTask<CheckSort> {
    @Override public void map( Chunk cs[] ) {
      long x0 = cs[0].at8(0);
      long x1 = cs[1].at8(0);
      long x2 = cs[2].at8(0);
      for( int i=1; i<cs[0]._len; i++ ) {
        long y0 = cs[0].at8(i);
        long y1 = cs[1].at8(i);
        long y2 = cs[2].at8(i);
        assertTrue(x0<y0 || (x0==y0 && (x1<y1 || (x1==y1 && x2<y2))));
        x0=y0; x1=y1; x2=y2;
      }
      // Last row of this chunk must sort ahead of the 1st row of the next chunk
      long row = cs[0].start()+cs[0]._len;
      if( row < cs[0].vec().length() ) {
        long y0 = cs[0].vec().at8(row);
        long y1 = cs[1].vec().at8(row);
        long y2 = cs[2].vec().at8(row);
        assertTrue(x0<y0 || (x0==y0 && (x1<y1 || (x1==y1 && x2<y2))));
      }
    }
  }

  // Build a 3 column frame.  Col #0 is categorical with # of cats given; col
  // #1 is categorical with 10x more choices.  A set of pairs of col#0 and
  // col#1 is made; each pair is given about 100 rows.  Col#2 is a row number.
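  // How the construction below works: each (col#0,col#1) pair is packed into a
  // single long as (col0<<32)|col1 and deduplicated in a NonBlockingHashMapLong;
  // rows are then dealt to random chunks, so the data is unsorted both within
  // and across chunks.  Col#2 is filled in afterwards with the global row
  // number (espc[] holds the prefix sums of the chunk lengths), which is what
  // CheckSort uses to verify that the sort was stable.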
  private static Frame buildFrame( int card0, int nChunks ) {
    // Compute the pairs
    int scale0 = 3;   // approximate ratio actual pairs vs all possible pairs; so scale0=3/scale1=10 is about 30% actual unique pairs
    int scale1 = 10;  // scale of |col#1| / |col#0|, i.e., col#1 has 10x more levels than col#0
    int scale2 = 100; // number of rows per pair
    if( nChunks == -1 ) {
      long len = (long)card0*(long)scale0*(long)scale2;
      int rowsPerChunk = 100000;
      nChunks = (int)((len+rowsPerChunk-1)/rowsPerChunk);
    }

    NonBlockingHashMapLong<String> pairs_hash = new NonBlockingHashMapLong<>();
    Random R = new Random(card0*scale0*nChunks);
    for( int i=0; i<card0*scale0; i++ ) {
      long pair = (((long)R.nextInt(card0))<<32) | (R.nextInt(card0*scale1));
      if( pairs_hash.containsKey(pair) ) i--; // Reroll dice on collisions
      else pairs_hash.put(pair,"");
    }
    long[] pairs = pairs_hash.keySetLong();

    Key[] keys = new Vec.VectorGroup().addVecs(3);
    AppendableVec col0 = new AppendableVec(keys[0], Vec.T_NUM);
    AppendableVec col1 = new AppendableVec(keys[1], Vec.T_NUM);
    AppendableVec col2 = new AppendableVec(keys[2], Vec.T_NUM);
    NewChunk ncs0[] = new NewChunk[nChunks];
    NewChunk ncs1[] = new NewChunk[nChunks];
    NewChunk ncs2[] = new NewChunk[nChunks];
    for( int i=0; i<nChunks; i++ ) {
      ncs0[i] = new NewChunk(col0,i);
      ncs1[i] = new NewChunk(col1,i);
      ncs2[i] = new NewChunk(col2,i);
    }

    // inject random pairs into cols 0 and 1
    int len = pairs.length*scale2;
    for( int i=0; i<len; i++ ) {
      long pair = pairs[R.nextInt(pairs.length)];
      int nchk = R.nextInt(nChunks);
      ncs0[nchk].addNum( (int)(pair>>32), 0);
      ncs1[nchk].addNum( (int)(pair    ), 0);
    }

    // Compute data layout
    int espc[] = new int[nChunks+1];
    for( int i=0; i<nChunks; i++ )
      espc[i+1] = espc[i] + ncs0[i].len();

    // Compute row numbers into col 2
    for( int i=0; i<nChunks; i++ )
      for( int j=0; j<ncs0[i].len(); j++ )
        ncs2[i].addNum(espc[i]+j, 0);

    Futures fs = new Futures();
    for( int i=0; i<nChunks; i++ ) {
      ncs0[i].close(i,fs);
      ncs1[i].close(i,fs);
      ncs2[i].close(i,fs);
    }

    Vec vec0 = col0.layout_and_close(fs);
    Vec vec1 = col1.layout_and_close(fs);
    Vec vec2 = col2.layout_and_close(fs);
    fs.blockForPending();
    Frame fr = new Frame(Key.<Frame>make("hex"), null, new Vec[]{vec0,vec1,vec2});
    DKV.put(fr);
    return fr;
  }

  @Test public void TestSortTimes() throws IOException {
    Frame fr=null, sorted=null;
    try {
      fr = parse_test_file("smalldata/synthetic/sort_crash.csv");
      sorted = fr.sort(new int[]{0});
      Vec vec = sorted.vec(0);
      int len = (int)vec.length();
      for( int i=1; i<len; i++ )
        assertTrue( vec.at8(i-1) <= vec.at8(i) );
    } finally {
      if( fr     != null ) fr.delete();
      if( sorted != null ) sorted.delete();
    }
  }
}