package hex; import org.junit.BeforeClass; import hex.drf.DRF; import hex.drf.DRF.DRFModel; import java.util.Random; import org.junit.Test; import water.*; public class DHistTest extends TestUtil { @BeforeClass public static void stall() { stall_till_cloudsize(1); } @Test public void testDBinom() { DRF drf = new DRF(); Key destTrain = Key.make("data.hex"); DRFModel model = null; try { // Configure DRF drf.source = parseFrame(destTrain, "../smalldata/histogram_test/alphabet_cattest.csv"); drf.response = drf.source.vecs()[1]; drf.classification = true; drf.ntrees = 100; drf.max_depth = 5; // = interaction.depth drf.min_rows = 10; // = nodesize drf.nbins = 100; drf.destination_key = Key.make("DRF_model_dhist.hex"); // Invoke DRF and block till the end drf.invoke(); // Get the model model = UKV.get(drf.dest()); } finally { drf.source.delete(); drf.remove(); if(model != null) model.delete(); // Remove the model } } public String[] sample(String[] levels, int n, long seed) { Random rand = new Random(seed); int ncat = levels.length; String[] samp = new String[n]; for(int i = 0; i < n; i++) samp[i] = levels[rand.nextInt(ncat)]; return samp; } private double prob(String lev) { if(lev == "A") return 0.8; if(lev == "B") return 0.6; if(lev == "C") return 0.4; if(lev == "D") return 0.2; return 0.5; } }