package hex.schemas;
import hex.ScoreKeeper;
import hex.grid.HyperSpaceSearchCriteria;
import water.api.API;
import water.api.schemas3.SchemaV3;
import water.exceptions.H2OIllegalArgumentException;
/**
* Search criteria for a hyperparameter search including directives for how to search and
* when to stop the search.
*/
public class HyperSpaceSearchCriteriaV99<I extends HyperSpaceSearchCriteria, S extends HyperSpaceSearchCriteriaV99<I,S>>
extends SchemaV3<I, S> {
@API(help = "Hyperparameter space search strategy.", required = true, values = { "Unknown", "Cartesian", "RandomDiscrete" }, direction = API.Direction.INOUT)
public HyperSpaceSearchCriteria.Strategy strategy;
// TODO: add a factory which accepts a Strategy and calls the right constructor
/**
* Search criteria for an exhaustive Cartesian hyperparameter search.
*/
public static class CartesianSearchCriteriaV99 extends HyperSpaceSearchCriteriaV99<HyperSpaceSearchCriteria.CartesianSearchCriteria, CartesianSearchCriteriaV99> {
public CartesianSearchCriteriaV99() {
strategy = HyperSpaceSearchCriteria.Strategy.Cartesian;
}
}
/**
* Search criteria for random hyperparameter search using hyperparameter values given by
* lists. Includes directives for how to search and when to stop the search.
*/
public static class RandomDiscreteValueSearchCriteriaV99 extends HyperSpaceSearchCriteriaV99<HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria, RandomDiscreteValueSearchCriteriaV99> {
public RandomDiscreteValueSearchCriteriaV99() {
strategy = HyperSpaceSearchCriteria.Strategy.RandomDiscrete;
}
public RandomDiscreteValueSearchCriteriaV99(long seed, int max_models, int max_runtime_secs) {
strategy = HyperSpaceSearchCriteria.Strategy.RandomDiscrete;
this.seed = seed;
this.max_models = max_models;
this.max_runtime_secs = max_runtime_secs;
}
@API(help = "Seed for random number generator; set to a value other than -1 for reproducibility.", required = false, direction = API.Direction.INOUT)
public long seed;
@API(help = "Maximum number of models to build (optional).", required = false, direction = API.Direction.INOUT)
public int max_models;
@API(help = "Maximum time to spend building models (optional).", required = false, direction = API.Direction.INOUT)
public double max_runtime_secs;
@API(help = "Early stopping based on convergence of stopping_metric. Stop if simple moving average of length k of the stopping_metric does not improve for k:=stopping_rounds scoring events (0 to disable)", level = API.Level.secondary, direction=API.Direction.INOUT, gridable = true)
public int stopping_rounds;
@API(help = "Metric to use for early stopping (AUTO: logloss for classification, deviance for regression)", values = {"AUTO", "deviance", "logloss", "MSE", "RMSE","MAE","RMSLE", "AUC", "lift_top_group", "misclassification", "mean_per_class_error"}, level = API.Level.secondary, direction=API.Direction.INOUT, gridable = true)
public ScoreKeeper.StoppingMetric stopping_metric;
@API(help = "Relative tolerance for metric-based stopping criterion Relative tolerance for metric-based stopping criterion (stop if relative improvement is not at least this much)", level = API.Level.secondary, direction=API.Direction.INOUT, gridable = true)
public double stopping_tolerance;
}
/**
* Fill with the default values from the corresponding Iced object.
*/
public S fillWithDefaults() {
HyperSpaceSearchCriteria defaults = null;
if (HyperSpaceSearchCriteria.Strategy.Cartesian == strategy) {
defaults = new HyperSpaceSearchCriteria.CartesianSearchCriteria();
} else if (HyperSpaceSearchCriteria.Strategy.RandomDiscrete == strategy) {
defaults = new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
} else {
throw new H2OIllegalArgumentException("search_criteria.strategy", strategy.toString());
}
fillFromImpl((I)defaults);
return (S) this;
}
}