package hex; import hex.glm.GLM2; import hex.glm.GLMModel; import hex.glm.GLMParams; import org.junit.BeforeClass; import org.junit.Test; import water.*; import water.fvec.Frame; import water.util.Log; import java.util.Random; public class GLMRandomTest extends TestUtil { @BeforeClass public static void stall() { stall_till_cloudsize(JUnitRunnerDebug.NODES); } private static class GLM2Test extends GLM2 { public void invokeServe() { Response r = serve(); if(r.error() != null) throw new IllegalArgumentException("Got error " + r.error()); _fjtask.join(); } public static void runFraction(float fraction) throws Throwable { long seed = new Random().nextLong(); Log.info("GLMRadnomTest: seed = " + seed); Random rng = new Random(seed); int testcount = 0; int count = 0; int jobId = 0; final Key dest = Key.make("DEST"); for(boolean offset:new boolean[]{true,false}){ for(boolean intercept:new boolean[]{true,false}) { for (int rows : new int[]{ 10, 100, // 1000, }) { for (int cols : new int[]{ 1, 10, // 100, }) { for (float categorical_fraction : intercept ? new float[]{ 0, 0.1f, 1 } : new float[]{0}) { for (int factors : new int[]{ 2, 10, }) { for (int response_factors : new int[]{ 1, //regression 2, //binomial }) { for (boolean positive_response : new boolean[]{ true, false }) { for (int max_iter : new int[]{ 1, 10 }) { for (boolean standardize : new boolean[]{ true, false, }) { for (int n_folds : new int[]{ 0, 3, }) { for (GLMParams.Family family : new GLMParams.Family[]{ GLMParams.Family.gamma, GLMParams.Family.gaussian, // GLMParams.Family.tweedie, tweedie is unstable now GLMParams.Family.binomial, GLMParams.Family.poisson, }) { if (response_factors != 2 && family == GLMParams.Family.binomial) continue; if (!positive_response && family == GLMParams.Family.gamma) continue; if (!positive_response && family == GLMParams.Family.tweedie) continue; for (GLMParams.Link link : new GLMParams.Link[]{ GLMParams.Link.family_default, // GLMParams.Link.identity, // GLMParams.Link.inverse, // GLMParams.Link.log, // GLMParams.Link.logit, // GLMParams.Link.tweedie, }) { switch (family) { case gaussian: if (link != GLMParams.Link.identity && link != GLMParams.Link.log && link != GLMParams.Link.inverse) continue; break; case binomial: if (link != GLMParams.Link.logit && link != GLMParams.Link.log) continue; break; case poisson: if (link != GLMParams.Link.log && link != GLMParams.Link.identity) continue; break; case gamma: if (link != GLMParams.Link.inverse && link != GLMParams.Link.log && link != GLMParams.Link.identity) continue; break; case tweedie: if (link != GLMParams.Link.tweedie) continue; break; } for (double tweedie_variance_power : new double[]{ 0 }) { for (double[] alpha : new double[][]{ new double[]{1e-5}, new double[]{1}, new double[]{0, 0.5, 1}, }) { for (double[] lambda : new double[][]{ new double[]{1e-4}, new double[]{1e-3, 1e-3}, }) { for (double beta_epsilon : new double[]{ 0, 1e-4, }) { for (boolean higher_accuracy : new boolean[]{ true, false, }) { for (boolean use_all_factor_levels : new boolean[]{ true, false, }) { for (boolean lambda_search : new boolean[]{ true, false, }) { for (boolean strong_rules : new boolean[]{ true, false, }) { for (int max_predictors : new int[]{ -1, 3, }) { for (int nlambdas : new int[]{ 1, 5, 50 }) { for (double lambda_min_ratio : new double[]{ 1e-2, 1e-4, }) { for (double prior : new double[]{ -1, 0.001 }) { for (boolean variable_importances : new boolean[]{ true, false, }) { count++; if (fraction < rng.nextFloat()) continue; CreateFrame cf = new CreateFrame(); cf.key = "random"; cf.rows = rows; cf.cols = cols; cf.categorical_fraction = categorical_fraction; cf.integer_fraction = 1 - categorical_fraction; cf.factors = factors; cf.response_factors = response_factors; cf.positive_response = positive_response; cf.seed = seed; cf.serve(); Frame frame = UKV.get(Key.make(cf.key)); Log.info("**************************)"); Log.info("Starting test #" + count); Log.info("**************************)"); { GLM2Test p = new GLM2Test(); p.job_key = Key.make("RandomGLM_" + jobId++); p.destination_key = dest; p.source = frame; p.response = frame.vecs()[0]; //response is always the first column p.max_iter = max_iter; p.standardize = standardize; p.n_folds = n_folds; p.family = family; p.link = link; p.tweedie_variance_power = tweedie_variance_power; p.alpha = alpha; p.lambda = lambda; p.beta_epsilon = beta_epsilon; p.higher_accuracy = higher_accuracy; p.use_all_factor_levels = use_all_factor_levels; p.lambda_search = lambda_search; p.strong_rules = strong_rules; p.max_predictors = max_predictors; p.nlambdas = nlambdas; p.lambda_min_ratio = lambda_min_ratio; p.prior = prior; p.variable_importances = variable_importances; p.MAX_ITERATIONS_PER_LAMBDA = 5; p.intercept = intercept; p.offset = (offset && frame.numCols() > 2)?frame.vec(1):null; try { p.invokeServe(); assert p._done; if (p.alpha.length > 1) new GLMGrid.DeleteGridTsk(null, p.destination_key).submitTask(); else new GLMModel.DeleteModelTask(null, p.destination_key).submitTask(); System.out.println("TEST DONE"); } catch (DException.DistributedException dex) { if (dex.getMessage().contains("IllegalArgument")) Log.info("Skipping invalid combination of arguments."); else throw new RuntimeException(dex); } catch (IllegalArgumentException t) { Log.info("Skipping invalid combination of arguments."); // accept IllegalArgumentException, but nothing else } finally { frame.delete(); } } Log.info("Parameters combination " + count + ": PASS"); testcount++; } } } } } } } } } } } } } } } } } } } } } } } } } Log.info("\n\n============================================="); Log.info("Tested " + testcount + " out of " + count + " parameter combinations."); Log.info("============================================="); } } } public static class Long extends GLMRandomTest { @Test public void run() throws Throwable { GLM2Test.runFraction(0.1f); } } public static class Short extends GLMRandomTest { @Test public void run() throws Throwable { GLM2Test.runFraction(1e-6f); } } }