package hex.schemas; import hex.Model; import hex.grid.Grid; import water.H2O; import water.Key; import water.api.*; import water.api.schemas3.JobV3; import water.api.schemas3.KeyV3; import water.api.schemas3.ModelParametersSchemaV3; import water.api.schemas3.SchemaV3; import water.exceptions.H2OIllegalArgumentException; import water.util.IcedHashMap; import java.util.List; import java.util.Map; import java.util.Properties; /** * This is a common grid search schema composed of two parameters: default parameters for a builder * and hyper parameters which are given as a mapping from parameter name to list of possible * values. * <p> * TODO: this needs a V99 subclass for bindings generation. * * @param <G> a specific implementation type for GridSearch holding results of grid search (model list) * @param <S> self type * @param <MP> actual model parameters type * @param <P> a specific model builder parameters schema, since we cannot derive it from P */ public class GridSearchSchema<G extends Grid<MP>, S extends GridSearchSchema<G, S, MP, P>, MP extends Model.Parameters, P extends ModelParametersSchemaV3> extends SchemaV3<G, S> { // // Inputs // @API(help = "Basic model builder parameters.", direction = API.Direction.INPUT) public P parameters; @API(help = "Grid search parameters.", direction = API.Direction.INOUT) public IcedHashMap<String, Object[]> hyper_parameters; @API(help = "Destination id for this grid; auto-generated if not specified.", required = false, direction = API.Direction.INOUT) public KeyV3.GridKeyV3 grid_id; @API(help="Hyperparameter search criteria, including strategy and early stopping directives. If it is not given, exhaustive Cartesian is used.", required = false, direction = API.Direction.INOUT) public HyperSpaceSearchCriteriaV99 search_criteria; // // Outputs // @API(help = "Number of all models generated by grid search.", direction = API.Direction.OUTPUT) public int total_models; @API(help = "Job Key.", direction = API.Direction.OUTPUT) public JobV3 job; @Override public S fillFromParms(Properties parms) { if( parms.containsKey("hyper_parameters") ) { Map<String,Object> m = water.util.JSONUtils.parse(parms.getProperty("hyper_parameters")); // Convert lists and singletons into arrays for (Map.Entry<String, Object> e : m.entrySet()) { Object o = e.getValue(); Object[] o2 = o instanceof List ? ((List) o).toArray() : new Object[]{o}; hyper_parameters.put(e.getKey(),o2); } parms.remove("hyper_parameters"); } if( parms.containsKey("search_criteria") ) { Properties p = water.util.JSONUtils.parseToProperties(parms.getProperty("search_criteria")); if (! p.containsKey("strategy")) { throw new H2OIllegalArgumentException("search_criteria.strategy", "null"); } // TODO: move this into a factory method in HyperSpaceSearchCriteriaV99 String strategy = (String)p.get("strategy"); if ("Cartesian".equals(strategy)) { search_criteria = new HyperSpaceSearchCriteriaV99.CartesianSearchCriteriaV99(); } else if ("RandomDiscrete".equals(strategy)) { search_criteria = new HyperSpaceSearchCriteriaV99.RandomDiscreteValueSearchCriteriaV99(); if (p.containsKey("max_runtime_secs") && Double.parseDouble((String) p.get("max_runtime_secs"))<0) { throw new H2OIllegalArgumentException("max_runtime_secs must be >= 0 (0 for unlimited time)", strategy); } if (p.containsKey("max_models") && Integer.parseInt((String) p.get("max_models"))<0) { throw new H2OIllegalArgumentException("max_models must be >= 0 (0 for all models)", strategy); } } else { throw new H2OIllegalArgumentException("search_criteria.strategy", strategy); } search_criteria.fillWithDefaults(); search_criteria.fillFromParms(p); parms.remove("search_criteria"); } else { // Fall back to Cartesian if there's no search_criteria specified. search_criteria = new HyperSpaceSearchCriteriaV99.CartesianSearchCriteriaV99(); } if (parms.containsKey("grid_id")) { grid_id = new KeyV3.GridKeyV3(Key.<Grid>make(parms.getProperty("grid_id"))); parms.remove("grid_id"); } // Do not check validity of parameters, GridSearch is tolerant of bad // parameters (on purpose, many hyper-param points in the grid might be // illegal for whatever reason). this.parameters.fillFromParms(parms, false); return (S) this; } @Override public S fillFromImpl(G impl) { throw H2O.unimpl(); //S s = super.fillFromImpl(impl); //s.parameters = createParametersSchema(); //s.parameters.fillFromImpl((MP) parameters.createImpl()); //return s; } }