package hex.createframe.recipes; import hex.createframe.CreateFrameExecutor; import hex.createframe.CreateFrameRecipe; import hex.createframe.columns.*; import hex.createframe.postprocess.MissingInserterCfps; import hex.createframe.postprocess.ShuffleColumnsCfps; /** * This recipe tries to match the behavior of the original hex.CreateFrame class. */ public class OriginalCreateFrameRecipe extends CreateFrameRecipe<OriginalCreateFrameRecipe> { private int rows = 10000; private int cols = 10; private double real_range = 100; private double categorical_fraction = 0.2; private int factors = 100; private boolean randomize = true; private long value = 0; private double integer_fraction = 0.2; private double time_fraction = 0.0; private double string_fraction = 0.0; private int integer_range = 100; private double binary_fraction = 0.1; private double binary_ones_fraction = 0.02; private double missing_fraction = 0.01; private int response_factors = 2; private boolean positive_response = false; // only for response_factors == 1 private boolean has_response = false; @Override protected void checkParametersValidity() { double total_fraction = integer_fraction + binary_fraction + categorical_fraction + time_fraction + string_fraction; check(total_fraction < 1.00000001, "Integer, binary, categorical, time and string fractions must add up to <= 1"); check(missing_fraction >= 0 && missing_fraction < 1, "Missing fraction must be between 0 and 1"); check(integer_fraction >= 0 && integer_fraction <= 1, "Integer fraction must be between 0 and 1"); check(binary_fraction >= 0 && binary_fraction <= 1, "Binary fraction must be between 0 and 1"); check(time_fraction >= 0 && time_fraction <= 1, "Time fraction must be between 0 and 1"); check(string_fraction >= 0 && string_fraction <= 1, "String fraction must be between 0 and 1"); check(binary_ones_fraction >= 0 && binary_ones_fraction <= 1, "Binary ones fraction must be between 0 and 1"); check(categorical_fraction >= 0 && categorical_fraction <= 1, "Categorical fraction must be between 0 and 1"); check(categorical_fraction == 0 || factors >= 2, "Factors must be larger than 2 for categorical data"); check(response_factors >= 1, "Response factors must be either 1 (real-valued response), or >=2 (factor levels)"); check(response_factors <= 1024, "Response factors must be <= 1024"); check(factors <= 1000000, "Number of factors must be <= 1,000,000"); check(cols > 0 && rows > 0, "Must have number of rows and columns > 0"); check(real_range >= 0, "Real range must be a nonnegative number"); check(integer_range >= 0, "Integer range must be a nonnegative number"); check(dest != null, "Destination frame must have a key"); if (positive_response) check(response_factors == 1, "positive_response can only be requested for real-valued response column"); if (randomize) check(value == 0, "Cannot set data to a constant value if randomize is true"); else { check(!has_response, "Cannot have response column if randomize is false"); check(total_fraction == 0, "Cannot have integer, categorical, string, binary or time columns if randomize is false"); } } @Override protected void buildRecipe(CreateFrameExecutor cfe) { cfe.setSeed(seed); cfe.setNumRows(rows); // Sometimes the client requests, say, 0.3 categorical columns. By the time this number arrives here, it becomes // something like 0.299999999997. If we just multiply by the number of columns (say 10000) and take integer part, // we'd have 2999 columns only -- not what the client expects. This is why we add 0.1 to each count before taking // the floor part. int catcols = (int)(categorical_fraction * cols + 0.1); int intcols = (int)(integer_fraction * cols + 0.1); int bincols = (int)(binary_fraction * cols + 0.1); int timecols = (int)(time_fraction * cols + 0.1); int stringcols = (int)(string_fraction * cols + 0.1); int realcols = cols - catcols - intcols - bincols - timecols - stringcols; // At this point we might have accidentally allocated too many columns -- in such case adjust their counts. if (realcols < 0 && catcols > 0) { catcols--; realcols++; } if (realcols < 0 && intcols > 0) { intcols--; realcols++; } if (realcols < 0 && bincols > 0) { bincols--; realcols++; } if (realcols < 0 && timecols > 0) { timecols--; realcols++; } if (realcols < 0 && stringcols > 0) { stringcols--; realcols++; } assert catcols >= 0 && intcols >= 0 && bincols >= 0 && realcols >= 0 && timecols >= 0 && stringcols >= 0; // Create response column if (has_response) { if (response_factors == 1) cfe.addColumnMaker(new RealColumnCfcm("response", positive_response? 0 : -real_range, real_range)); else cfe.addColumnMaker(new CategoricalColumnCfcm("response", response_factors)); } // Create "feature" columns if (randomize) { int j = 0; for (int i = 0; i < intcols; i++) cfe.addColumnMaker(new IntegerColumnCfcm("C" + (++j), -integer_range, integer_range)); for (int i = 0; i < realcols; i++) cfe.addColumnMaker(new RealColumnCfcm("C" + (++j), -real_range, real_range)); for (int i = 0; i < catcols; i++) cfe.addColumnMaker(new CategoricalColumnCfcm("C" + (++j), factors)); for (int i = 0; i < bincols; i++) cfe.addColumnMaker(new BinaryColumnCfcm("C" + (++j), binary_ones_fraction)); for (int i = 0; i < timecols; i++) cfe.addColumnMaker(new TimeColumnCfcm("C" + (++j), 0, 50L * 365 * 24 * 3600 * 1000)); // 1970...2020 for (int i = 0; i < stringcols; i++) cfe.addColumnMaker(new StringColumnCfcm("C" + (++j), 8)); } else { assert catcols + intcols + bincols + timecols + stringcols == 0; for (int i = 0; i < realcols; i++) cfe.addColumnMaker(new RealColumnCfcm("C" + (i+1), value, value)); } // Add post-processing steps cfe.addPostprocessStep(new MissingInserterCfps(missing_fraction)); cfe.addPostprocessStep(new ShuffleColumnsCfps(true, true)); } }