package hex.schemas; import hex.Model; import hex.ModelMetrics; import hex.grid.Grid; import water.DKV; import water.Key; import water.api.*; import water.api.schemas3.*; import water.exceptions.H2OIllegalArgumentException; import water.util.TwoDimTable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Set; /** * REST endpoint representing single grid object. * * FIXME: Grid should contain also grid definition - model parameters and definition of hyper parameters. */ public class GridSchemaV99 extends SchemaV3<Grid, GridSchemaV99> { // // Inputs // @API(help = "Grid id") public KeyV3.GridKeyV3 grid_id; @API(help = "Model performance metric to sort by. Examples: logloss, residual_deviance, mse, rmse, mae,rmsle, auc, r2, f1, recall, precision, accuracy, mcc, err, err_count, lift_top_group, max_per_class_error", required = false, direction = API.Direction.INOUT) public String sort_by; @API(help = "Specify whether sort order should be decreasing.", required = false, direction = API.Direction.INOUT) public boolean decreasing; // // Outputs // @API(help = "Model IDs built by a grid search") public KeyV3.ModelKeyV3[] model_ids; @API(help = "Used hyper parameters.", direction = API.Direction.OUTPUT) public String[] hyper_names; @API(help = "List of failed parameters", direction = API.Direction.OUTPUT) public ModelParametersSchemaV3[] failed_params; // Using common ancestor of XXXParamsV3 @API(help = "List of detailed failure messages", direction = API.Direction.OUTPUT) public String[] failure_details; @API(help = "List of detailed failure stack traces", direction = API.Direction.OUTPUT) public String[] failure_stack_traces; @API(help = "List of raw parameters causing model building failure", direction = API.Direction.OUTPUT) public String[][] failed_raw_params; @API(help = "Training model metrics for the returned models; only returned if sort_by is set", direction = API.Direction.OUTPUT) public ModelMetricsBaseV3[] training_metrics; @API(help = "Validation model metrics for the returned models; only returned if sort_by is set", direction = API.Direction.OUTPUT) public ModelMetricsBaseV3[] validation_metrics; @API(help = "Cross validation model metrics for the returned models; only returned if sort_by is set", direction = API.Direction.OUTPUT) public ModelMetricsBaseV3[] cross_validation_metrics; @API(help = "Cross validation model metrics summary for the returned models; only returned if sort_by is set", direction = API.Direction.OUTPUT) public TwoDimTableV3[] cross_validation_metrics_summary; @API(help="Summary", direction=API.Direction.OUTPUT) TwoDimTableV3 summary_table; @API(help="Scoring history", direction=API.Direction.OUTPUT, level=API.Level.secondary) TwoDimTableV3 scoring_history; @Override public Grid createImpl() { return Grid.GRID_PROTO; } @Override public GridSchemaV99 fillFromImpl(Grid grid) { Key<Model>[] gridModelKeys = grid.getModelKeys(); // Return only keys which are referencing to existing objects in DKV // However, here is still implicit race, since we are sending // keys to client, but referenced models can be deleted in meantime // Hence, client has to be responsible for handling this situation // - call getModel and check for null model List<Key<Model>> modelKeys = new ArrayList<>(gridModelKeys.length); // pre-allocate for (Key k : gridModelKeys) { if (k != null && DKV.get(k) != null) { modelKeys.add(k); } } // Default sort order -- TODO: Outsource if (sort_by == null && modelKeys.size() > 0 && modelKeys.get(0) != null) { Model m = DKV.getGet(modelKeys.get(0)); Model.GridSortBy sortBy = m != null ? m.getDefaultGridSortBy() : null; if (sortBy != null) { sort_by = sortBy._name; decreasing = sortBy._decreasing; } } // Check that we have a valid metric // If not, show all possible metrics if (modelKeys.size() > 0 && sort_by != null) { Set<String> possibleMetrics = ModelMetrics.getAllowedMetrics(modelKeys.get(0)); if (!possibleMetrics.contains(sort_by.toLowerCase())) { throw new H2OIllegalArgumentException("Invalid argument for sort_by specified. Must be one of: " + Arrays.toString(possibleMetrics.toArray(new String[0]))); } } // Are we sorting by model metrics? if (null != sort_by && ! sort_by.isEmpty()) { // sort the model keys modelKeys = ModelMetrics.sortModelsByMetric(sort_by, decreasing, modelKeys); // fill the metrics arrays training_metrics = new ModelMetricsBaseV3[modelKeys.size()]; validation_metrics = new ModelMetricsBaseV3[modelKeys.size()]; cross_validation_metrics = new ModelMetricsBaseV3[modelKeys.size()]; cross_validation_metrics_summary = new TwoDimTableV3[modelKeys.size()]; for (int i = 0; i < modelKeys.size(); i++) { Model m = DKV.getGet(modelKeys.get(i)); if (m != null) { Model.Output o = m._output; if (null != o._training_metrics) training_metrics[i] = (ModelMetricsBaseV3) SchemaServer.schema(3, o._training_metrics).fillFromImpl(o ._training_metrics); if (null != o._validation_metrics) validation_metrics[i] = (ModelMetricsBaseV3) SchemaServer.schema(3, o ._validation_metrics).fillFromImpl(o._validation_metrics); if (null != o._cross_validation_metrics) cross_validation_metrics[i] = (ModelMetricsBaseV3) SchemaServer .schema(3, o._cross_validation_metrics).fillFromImpl(o._cross_validation_metrics); if (o._cross_validation_metrics_summary != null) cross_validation_metrics_summary[i] = new TwoDimTableV3(o._cross_validation_metrics_summary); } } } KeyV3.ModelKeyV3[] modelIds = new KeyV3.ModelKeyV3[modelKeys.size()]; Key<Model>[] keys = new Key[modelKeys.size()]; for (int i = 0; i < modelIds.length; i++) { modelIds[i] = new KeyV3.ModelKeyV3(modelKeys.get(i)); keys[i] = modelIds[i].key(); } grid_id = new KeyV3.GridKeyV3(grid._key); model_ids = modelIds; hyper_names = grid.getHyperNames(); failed_params = toModelParametersSchema(grid.getFailedParameters()); failure_details = grid.getFailureDetails(); failure_stack_traces = grid.getFailureStackTraces(); failed_raw_params = grid.getFailedRawParameters(); TwoDimTable t = grid.createSummaryTable(keys, sort_by, decreasing); if (t!=null) summary_table = new TwoDimTableV3().fillFromImpl(t); TwoDimTable h = grid.createScoringHistoryTable(); if (h != null) scoring_history = new TwoDimTableV3().fillFromImpl(h); return this; } private ModelParametersSchemaV3[] toModelParametersSchema(Model.Parameters[] modelParameters) { if (modelParameters==null) return null; ModelParametersSchemaV3[] result = new ModelParametersSchemaV3[modelParameters.length]; for (int i = 0; i < modelParameters.length; i++) { if (modelParameters[i] != null) { result[i] = (ModelParametersSchemaV3) SchemaServer.schema(SchemaServer.getLatestVersion(), modelParameters[i]) .fillFromImpl(modelParameters[i]); } else { result[i] = null; } } return result; } }