package hex.grid; import hex.*; import hex.grid.HyperSpaceWalker.BaseWalker; import water.*; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Frame; import water.util.Log; import water.util.PojoUtils; import java.io.PrintWriter; import java.io.StringWriter; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.Map; /** * Grid search job. * * This job represents a generic interface to launch "any" hyper space * search. It triggers sub-jobs for each point in hyper space. It produces * <code>Grid</code> object which contains a list of build models. A triggered * model builder job can fail! * * Grid search is parametrized by hyper space walk strategy ({@link * hex.grid.HyperSpaceWalker} which defines how the space of hyper parameters * is traversed. * * The job is started by the <code>startGridSearch</code> method which create a new grid search, put * representation of Grid into distributed KV store, and for each parameter in hyper space of * possible parameters, it launches a separated model building job. The launch of jobs is sequential * and blocking. So after finish the last model, whole grid search job is done as well. * * By default, the grid search invokes cartezian grid search, but it can be * modified by passing explicit hyper space walk strategy via the * {@link #startGridSearch(Key, HyperSpaceWalker)} method. * * If any of forked jobs fails then the failure is ignored, and grid search * normally continue in traversing the hyper space. * * Typical usage from Java is: * <pre>{@code * // Create initial parameters and fill them by references to data * GBMModel.GBMParameters params = new GBMModel.GBMParameters(); * params._train = fr._key; * params._response_column = "cylinders"; * * // Define hyper-space to search * HashMap<String,Object[]> hyperParms = new HashMap<>(); * hyperParms.put("_ntrees", new Integer[]{1, 2}); * hyperParms.put("_distribution",new DistributionFamily[] {DistributionFamily.multinomial}); * hyperParms.put("_max_depth",new Integer[]{1,2,5}); * hyperParms.put("_learn_rate",new Float[]{0.01f,0.1f,0.3f}); * * // Launch grid search job creating GBM models * GridSearch gridSearchJob = GridSearch.startGridSearch(params, hyperParms, GBM_MODEL_FACTORY); * * // Block till the end of the job and get result * Grid grid = gridSearchJob.get() * * // Get built models * Model[] models = grid.getModels() * }</pre> * * @see hex.grid.HyperSpaceWalker * @see #startGridSearch(Key, HyperSpaceWalker) */ public final class GridSearch<MP extends Model.Parameters> extends Keyed<GridSearch> { public final Key<Grid> _result; public final Job<Grid> _job; /** Walks hyper space and for each point produces model parameters. It is * used only locally to fire new model builders. */ private final transient HyperSpaceWalker<MP, ?> _hyperSpaceWalker; private GridSearch(Key<Grid> gkey, HyperSpaceWalker<MP, ?> hyperSpaceWalker) { assert hyperSpaceWalker != null : "Grid search needs to know how to walk around hyper space!"; _hyperSpaceWalker = hyperSpaceWalker; _result = gkey; String algoName = hyperSpaceWalker.getParams().algoName(); _job = new Job<>(gkey, Grid.class.getName(), algoName + " Grid Search"); // Note: do not validate parameters of created model builders here! // Leave it to launch time, and just mark the corresponding model builder job as failed. } Job<Grid> start() { final long gridSize = _hyperSpaceWalker.getMaxHyperSpaceSize(); Log.info("Starting gridsearch: estimated size of search space = " + gridSize); // Create grid object and lock it // Creation is done here, since we would like make sure that after leaving // this function the grid object is in DKV and accessible. final Grid<MP> grid; Keyed keyed = DKV.getGet(_result); if (keyed != null) { if (! (keyed instanceof Grid)) throw new H2OIllegalArgumentException("Name conflict: tried to create a Grid using the ID of a non-Grid object that's already in H2O: " + _job._result + "; it is a: " + keyed.getClass()); grid = (Grid) keyed; Frame specTrainFrame = _hyperSpaceWalker.getParams().train(); Frame oldTrainFrame = grid.getTrainingFrame(); if (oldTrainFrame != null && !specTrainFrame._key.equals(oldTrainFrame._key) || oldTrainFrame != null && specTrainFrame.checksum() != oldTrainFrame.checksum()) throw new H2OIllegalArgumentException("training_frame", "grid", "Cannot append new models to a grid with different training input"); grid.write_lock(_job); } else { grid = new Grid<>(_result, _hyperSpaceWalker.getParams(), _hyperSpaceWalker.getHyperParamNames(), _hyperSpaceWalker.getParametersBuilderFactory().getFieldNamingStrategy()); grid.delete_and_lock(_job); } Model model = null; HyperSpaceWalker.HyperSpaceIterator<MP> it = _hyperSpaceWalker.iterator(); long gridWork=0; if (gridSize > 0) {//if total grid space is known, walk it all and count up models to be built (not subject to time-based or converge-based early stopping) int count=0; while (it.hasNext(model) && (it.max_models() > 0 && count++ < it.max_models())) { //only walk the first max_models models, if specified try { Model.Parameters parms = it.nextModelParameters(model); gridWork += (parms._nfolds > 0 ? (parms._nfolds+1/*main model*/) : 1) *parms.progressUnits(); } catch(Throwable ex) { //swallow invalid combinations } } } else { //TODO: Future totally unbounded search: need a time-based progress bar gridWork = Long.MAX_VALUE; } it.reset(); // Install this as job functions return _job.start(new H2O.H2OCountedCompleter() { @Override public void compute2() { gridSearch(grid); tryComplete(); } }, gridWork, it.max_runtime_secs()); } /** * Returns expected number of models in resulting Grid object. * * The number can differ from final number of models due to visiting duplicate points in hyper * space. * * @return expected number of models produced by this grid search */ public long getModelCount() { return _hyperSpaceWalker.getMaxHyperSpaceSize(); } /** * Invokes grid search based on specified hyper space walk strategy. * * It updates passed grid object in distributed store. * * @param grid grid object to save results; grid already locked */ private void gridSearch(Grid<MP> grid) { Model model = null; // Prepare nice model key and override default key by appending model counter //String protoModelKey = _hyperSpaceWalker.getParams()._model_id == null // ? grid._key + "_model_" // : _hyperSpaceWalker.getParams()._model_id.toString() + H2O.calcNextUniqueModelId("") + "_"; String protoModelKey = grid._key + "_model_"; try { // Get iterator to traverse hyper space HyperSpaceWalker.HyperSpaceIterator<MP> it = _hyperSpaceWalker.iterator(); // Number of traversed model parameters int counter = grid.getModelCount(); while (it.hasNext(model)) { if(_job.stop_requested() ) return; // Handle end-user cancel request double max_runtime_secs = it.max_runtime_secs(); double time_remaining_secs = Double.MAX_VALUE; if (max_runtime_secs > 0) { time_remaining_secs = it.time_remaining_secs(); if (time_remaining_secs < 0) { Log.info("Grid max_runtime_secs of " + max_runtime_secs + " secs has expired; stopping early."); return; } } MP params; try { // Get parameters for next model params = it.nextModelParameters(model); // Sequential model building, should never propagate // exception up, just mark combination of model parameters as wrong // Do we need to limit the model build time? if (max_runtime_secs > 0) { Log.info("Grid time is limited to: " + max_runtime_secs + " for grid: " + grid._key + ". Remaining time is: " + time_remaining_secs); double scale = params._nfolds > 0 ? params._nfolds+1 : 1; //remaining time per cv model is less if (params._max_runtime_secs == 0) { // unlimited params._max_runtime_secs = time_remaining_secs/scale; Log.info("Due to the grid time limit, changing model max runtime to: " + params._max_runtime_secs + " secs."); } else { double was = params._max_runtime_secs; params._max_runtime_secs = Math.min(params._max_runtime_secs, time_remaining_secs/scale); Log.info("Due to the grid time limit, changing model max runtime from: " + was + " secs to: " + params._max_runtime_secs + " secs."); } } try { ScoringInfo scoringInfo = new ScoringInfo(); scoringInfo.time_stamp_ms = System.currentTimeMillis(); //// build the model! model = buildModel(params, grid, counter++, protoModelKey); if (model!=null) { model.fillScoringInfo(scoringInfo); grid.setScoringInfos(ScoringInfo.prependScoringInfo(scoringInfo, grid.getScoringInfos())); ScoringInfo.sort(grid.getScoringInfos(), _hyperSpaceWalker.search_criteria().stopping_metric()); // Currently AUTO for Cartesian and user-specified for RandomDiscrete } } catch (RuntimeException e) { // Catch everything if (!Job.isCancelledException(e)) { StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); e.printStackTrace(pw); Log.warn("Grid search: model builder for parameters " + params + " failed! Exception: ", e, sw.toString()); } grid.appendFailedModelParameters(params, e); } } catch (IllegalArgumentException e) { Log.warn("Grid search: construction of model parameters failed! Exception: ", e); // Model parameters cannot be constructed for some reason it.modelFailed(model); Object[] rawParams = it.getCurrentRawParameters(); grid.appendFailedModelParameters(rawParams, e); } finally { // Update progress by 1 increment _job.update(1); // Always update grid in DKV after model building attempt grid.update(_job); } // finally if (model != null && grid.getScoringInfos() != null && // did model build and scoringInfo creation succeed? _hyperSpaceWalker.stopEarly(model, grid.getScoringInfos())) { Log.info("Convergence detected based on simple moving average of the loss function. Grid building completed."); break; } } // while (it.hasNext(model)) Log.info("For grid: " + grid._key + " built: " + grid.getModelCount() + " models."); } finally { grid.unlock(_job); } } /** * Build a model based on specified parameters and save it to resulting Grid object. * * Returns a model run with these parameters, typically built on demand and cached - expected to * be an expensive operation. If the model in question is "in progress", a 2nd build will NOT be * kicked off. This is a blocking call. * * If a new model is created, then the Grid object is updated in distributed store. If a model for * given parameters already exists, it is directly returned without updating the Grid object. If * model building fails then the Grid object is not updated and the method returns * <code>null</code>. * * @param params parameters for a new model * @param grid grid object holding created models * @param paramsIdx index of generated model parameter * @param protoModelKey prototype of model key * @return return a new model if it does not exist */ private Model buildModel(final MP params, Grid<MP> grid, int paramsIdx, String protoModelKey) { // Make sure that the model is not yet built (can be case of duplicated hyper parameters). // We first look in the grid _models cache, then we look in the DKV. // FIXME: get checksum here since model builder will modify instance of params!!! final long checksum = params.checksum(); Key<Model> key = grid.getModelKey(checksum); if (key != null) { if (DKV.get(key) == null) { // We know about a model that's been removed; rebuild. Log.info("GridSearch.buildModel(): model with these parameters was built but removed, rebuilding; checksum: " + checksum); } else { Log.info("GridSearch.buildModel(): model with these parameters already exists, skipping; checksum: " + checksum); return key.get(); } } // Is there a model with the same params in the DKV? final Key<Model>[] modelKeys = KeySnapshot.globalSnapshot().filter(new KeySnapshot.KVFilter() { @Override public boolean filter(KeySnapshot.KeyInfo k) { return Value.isSubclassOf(k._type, Model.class) && ((Model)k._key.get())._parms.checksum() == checksum; } }).keys(); if (modelKeys.length > 0) { grid.putModel(checksum, modelKeys[0]); return modelKeys[0].get(); } // Modify model key to have nice version with counter // Note: Cannot create it before checking the cache since checksum would differ for each model Key<Model> result = Key.make(protoModelKey + paramsIdx); // Build a new model // THIS IS BLOCKING call since we do not have enough information about free resources // FIXME: we should allow here any launching strategy (not only sequential) Model m = (Model)startBuildModel(result,params, grid).dest().get(); grid.putModel(checksum, result); return m; } /** * Triggers model building process but do not block on it. * * @param params parameters for a new model * @param grid resulting grid object * @return A Future of a model run with these parameters, typically built on demand and not cached * - expected to be an expensive operation. If the model in question is "in progress", a 2nd * build will NOT be kicked off. This is a non-blocking call. */ private ModelBuilder startBuildModel(Key result, MP params, Grid<MP> grid) { if (grid.getModel(params) != null) return null; ModelBuilder mb = ModelBuilder.make(params.algoName(), _job, result); mb._parms = params; mb.trainModelNested(null); return mb; } /** * Defines a key for a new Grid object holding results of grid search. * * @return a grid key for a particular modeling class and frame. * @throws java.lang.IllegalArgumentException if frame is not saved to distributed store. */ protected static Key<Grid> gridKeyName(String modelName, Frame fr) { if (fr == null || fr._key == null) { throw new IllegalArgumentException("The frame being grid-searched over must have a Key"); } return Key.make("Grid_" + modelName + "_" + fr._key.toString() + H2O.calcNextUniqueModelId("")); } /** * Start a new grid search job. This is the method that gets called by GridSearchHandler.do_train(). * <p> * This method launches a "classical" grid search traversing cartesian grid of parameters * point-by-point, <b>or</b> a random hyperparameter search, depending on the value of the <i>strategy</i> * parameter. * * @param destKey A key to store result of grid search under. * @param params Default parameters for model builder. This object is used to create * a specific model parameters for a combination of hyper parameters. * @param hyperParams A set of arrays of hyper parameter values, used to specify a simple * fully-filled-in grid search. * @param paramsBuilderFactory defines a strategy for creating a new model parameters based on * common parameters and list of hyper-parameters * @return GridSearch Job, with models run with these parameters, built as needed - expected to be * an expensive operation. If the models in question are "in progress", a 2nd build will NOT be * kicked off. This is a non-blocking call. */ public static <MP extends Model.Parameters> Job<Grid> startGridSearch( final Key<Grid> destKey, final MP params, final Map<String, Object[]> hyperParams, final ModelParametersBuilderFactory<MP> paramsBuilderFactory, final HyperSpaceSearchCriteria search_criteria) { return startGridSearch(destKey, BaseWalker.WalkerFactory.create(params, hyperParams, paramsBuilderFactory, search_criteria)); } /** * Start a new grid search job. * * <p>This method launches "classical" grid search traversing cartesian grid of parameters * point-by-point. For more advanced hyperparameter search behavior call the referenced method. * * @param destKey A key to store result of grid search under. * @param params Default parameters for model builder. This object is used to create a * specific model parameters for a combination of hyper parameters. * @param hyperParams A set of arrays of hyper parameter values, used to specify a simple * fully-filled-in grid search. * @return GridSearch Job, with models run with these parameters, built as needed - expected to be * an expensive operation. If the models in question are "in progress", a 2nd build will NOT be * kicked off. This is a non-blocking call. * * @see #startGridSearch(Key, Model.Parameters, Map, ModelParametersBuilderFactory, HyperSpaceSearchCriteria) */ public static <MP extends Model.Parameters> Job<Grid> startGridSearch(final Key<Grid> destKey, final MP params, final Map<String, Object[]> hyperParams) { return startGridSearch( destKey, params, hyperParams, new SimpleParametersBuilderFactory<MP>(), new HyperSpaceSearchCriteria.CartesianSearchCriteria()); } /** * Start a new grid search job. <p> This method launches any grid search traversing space of hyper * parameters based on specified strategy. * * @param destKey A key to store result of grid search under. * @param hyperSpaceWalker defines a strategy for traversing a hyper space. The object itself * holds definition of hyper space. * @return GridSearch Job, with models run with these parameters, built as needed - expected to be * an expensive operation. If the models in question are "in progress", a 2nd build will NOT be * kicked off. This is a non-blocking call. */ public static <MP extends Model.Parameters> Job<Grid> startGridSearch( final Key<Grid> destKey, final HyperSpaceWalker<MP, ?> hyperSpaceWalker) { // Compute key for destination object representing grid MP params = hyperSpaceWalker.getParams(); Key<Grid> gridKey = destKey != null ? destKey : gridKeyName(params.algoName(), params.train()); // Start the search return new GridSearch(gridKey, hyperSpaceWalker).start(); } /** * The factory is producing a parameters builder which uses reflection to setup field values. * * @param <MP> type of model parameters object */ public static class SimpleParametersBuilderFactory<MP extends Model.Parameters> implements ModelParametersBuilderFactory<MP> { @Override public ModelParametersBuilder<MP> get(MP initialParams) { return new SimpleParamsBuilder<>(initialParams); } @Override public PojoUtils.FieldNaming getFieldNamingStrategy() { return PojoUtils.FieldNaming.CONSISTENT; } /** * The builder modifies initial model parameters directly by reflection. * * Usage: * <pre>{@code * GBMModel.GBMParameters params = * new SimpleParamsBuilder(initialParams) * .set("_ntrees", 30).set("_learn_rate", 0.01).build() * }</pre> * * @param <MP> type of model parameters object */ public static class SimpleParamsBuilder<MP extends Model.Parameters> implements ModelParametersBuilder<MP> { final private MP params; public SimpleParamsBuilder(MP initialParams) { params = initialParams; } @Override public ModelParametersBuilder<MP> set(String name, Object value) { PojoUtils.setField(params, name, value, PojoUtils.FieldNaming.CONSISTENT); return this; } @Override public MP build() { return params; } } } }