package hex.grid; import hex.ScoreKeeper; import water.Iced; import water.fvec.Frame; /** * Search criteria for a hyperparameter search including directives for how to search and * when to stop the search. */ public class HyperSpaceSearchCriteria extends Iced { public enum Strategy { Unknown, Cartesian, RandomDiscrete } // search strategy public final Strategy _strategy; public final Strategy strategy() { return _strategy; } public ScoreKeeper.StoppingMetric stopping_metric() { return ScoreKeeper.StoppingMetric.AUTO; } // TODO: add a factory which accepts a Strategy and calls the right constructor public HyperSpaceSearchCriteria(Strategy strategy) { this._strategy = strategy; } /** * Search criteria for an exhaustive Cartesian hyperparameter search. */ public static final class CartesianSearchCriteria extends HyperSpaceSearchCriteria { public CartesianSearchCriteria() { super(Strategy.Cartesian); } } /** * Search criteria for a hyperparameter search including directives for how to search and * when to stop the search. * <p> * NOTE: client ought to call set_default_stopping_tolerance_for_frame(Frame) to get a reasonable stopping tolerance, especially for small N. */ public static final class RandomDiscreteValueSearchCriteria extends HyperSpaceSearchCriteria { private long _seed = -1; // -1 means true random ///////////////////// // stopping criteria: private int _max_models = 0; private double _max_runtime_secs = 0; private int _stopping_rounds = 0; private ScoreKeeper.StoppingMetric _stopping_metric = ScoreKeeper.StoppingMetric.AUTO; public double _stopping_tolerance = 0.001; /** Seed for the random choices of hyperparameter values. Set to a value other than -1 to get a repeatable pseudorandom sequence. */ public long seed() { return _seed; } /** Max number of models to build. */ public int max_models() { return _max_models; } /** * Max runtime for the entire grid, in seconds. Set to 0 to disable. Can be combined with <i>max_runtime_secs</i> in the model parameters. If * <i>max_runtime_secs</i> is not set in the model parameters then each model build is launched with a limit equal to * the remainder of the grid time. If <i>max_runtime_secs</i> <b>is</b> set in the mode parameters each build is launched * with a limit equal to the minimum of the model time limit and the remaining time for the grid. */ public double max_runtime_secs() { return _max_runtime_secs; } /** * Early stopping based on convergence of stopping_metric. * Stop if simple moving average of the stopping_metric does not improve by stopping_tolerance for * k scoring events. * Can only trigger after at least 2k scoring events. Use 0 to disable. */ public int stopping_rounds() { return _stopping_rounds; } /** Metric to use for convergence checking; only for _stopping_rounds > 0 */ public ScoreKeeper.StoppingMetric stopping_metric() { return _stopping_metric; } /** Relative tolerance for metric-based stopping criterion: stop if relative improvement is not at least this much. */ public double stopping_tolerance() { return _stopping_tolerance; } /** Calculate a reasonable stopping tolerance for the Frame. * Currently uses only the NA percentage and nrows, but later * can take into account the response distribution, response variance, etc. * <p> * <pre>1/Math.sqrt(frame.naFraction() * frame.numRows())</pre> */ public static double default_stopping_tolerance_for_frame(Frame frame) { return Math.min(0.05, Math.max(0.001, 1/Math.sqrt(frame.naFraction() * frame.numRows()))); } public void set_default_stopping_tolerance_for_frame(Frame frame) { _stopping_tolerance = default_stopping_tolerance_for_frame(frame); } public RandomDiscreteValueSearchCriteria() { super(Strategy.RandomDiscrete); } public void set_seed(long _seed) { this._seed = _seed; } public void set_max_models(int _max_models) { this._max_models = _max_models; } public void set_max_runtime_secs(double _max_runtime_secs) { this._max_runtime_secs = _max_runtime_secs; } public void set_stopping_rounds(int _stopping_rounds) { this._stopping_rounds = _stopping_rounds; } public void set_stopping_metric(ScoreKeeper.StoppingMetric _stopping_metric) { this._stopping_metric = _stopping_metric; } public void set_stopping_tolerance(double _stopping_tolerance) { this._stopping_tolerance = _stopping_tolerance; } } }