package hex.singlenoderf; import water.*; import water.util.Utils; //import hex.singlenoderf.TreeP; import java.util.ArrayList; import java.util.Arrays; public abstract class Sampling { /** Available sampling strategies. */ public enum Strategy { RANDOM(0); //, int _id; // redundant id private Strategy(int id) { _id = id; } } abstract Data sample(final Data data, long seed, Key modelKey, boolean local_mode); /** Deterministically sample the Data at the bagSizePct. Toss out invalid rows (as-if not sampled), but maintain the sampling rate. */ final static class Random extends Sampling { final double _bagSizePct; final int[] _rowsPerChunks; public Random(double bagSizePct, int[] rowsPerChunks) { _bagSizePct = bagSizePct; _rowsPerChunks = rowsPerChunks; } @Override Data sample(final Data data, long seed, Key modelKey, boolean local_mode) { SpeeDRFModel m = UKV.get(modelKey); int [] sample; sample = sampleFair(data,seed,_rowsPerChunks); // add the remaining rows Arrays.sort(sample); // we want an ordered sample return new Subset(data, sample, 0, sample.length); } /** Roll a fair die for sampling, resetting the random die every numrows. */ private int[] sampleFair(final Data data, long seed, int[] rowsPerChunks ) { // init java.util.Random rand = null; int rows = data.rows(); int size = bagSize(rows,_bagSizePct); int[] sample = MemoryManager.malloc4((int) (size * 1.10)); float f = (float) _bagSizePct; int cnt = 0; // Counter for resetting Random int j = 0; // Number of selected samples int cidx = 0; // Chunks counter // compute for( int i=0; i<rows; i++ ) { if( cnt--==0 ) { /* NOTE: Before changing used generator think about which kind of random generator you need: * if always deterministic or non-deterministic version - see hex.singlenoderf.Utils.get{Deter}RNG */ long chunkSamplingSeed = chunkSampleSeed(seed, i); // DEBUG: System.err.println(seed + " : " + i + " (sampling)"); rand = Utils.getDeterRNG(chunkSamplingSeed); cnt = rowsPerChunks[cidx++]-1; } float randFloat = rand.nextFloat(); if( randFloat < f ) { if( j == sample.length ) sample = Arrays.copyOfRange(sample,0,(int)(1 + sample.length*1.2)); sample[j++] = i; } } return Arrays.copyOf(sample,j); // Trim out bad rows } } /** * ! CRITICAL code ! * This method returns the correct seed based on initial seed and row index. * WARNING : this method is crucial for correct replay of sampling. */ static final long chunkSampleSeed(long seed, int rowIdx) { return seed + ((long)rowIdx<<16); } static final int bagSize( int rows, double bagSizePct ) { int size = (int)(rows * bagSizePct); return (size>0 || rows==0) ? size : 1; } }