package water.api;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelParametersBuilderFactory;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.schemas.*;
import water.H2O;
import water.Job;
import water.Key;
import water.TypeMap;
import water.api.schemas3.JobV3;
import water.api.schemas3.ModelParametersSchemaV3;
import water.exceptions.H2OIllegalArgumentException;
import water.util.IcedHashMap;
import water.util.PojoUtils;
import java.lang.reflect.Field;
import java.util.*;
/**
* A generic grid search handler implementing launch of grid search.
*
* <p>Model-specific grid search handlers should inherit from this class and implement the
* corresponding methods.
*
* FIXME: how to get rid of P, since it is already enforced by S
*
* @param <G> Implementation output of grid search
* @param <MP> Type of model parameters
* @param <P> Type of schema representing model parameters
* @param <S> Schema representing structure of grid search end-point
*/
public class GridSearchHandler<G extends Grid<MP>,
S extends GridSearchSchema<G, S, MP, P>,
MP extends Model.Parameters,
P extends ModelParametersSchemaV3> extends Handler {
/**
 * Invokes the grid search handler with the given request parameters. Can throw any exception
 * the called code can throw.
 *
 * <p>The algorithm is taken from the route URL, a grid search schema is created and filled with
 * the builder's default parameters overlaid by the request parameters, the hyper parameter names
 * are validated, and finally the grid search job is started.
 *
 * <p>TODO: why does this do its own params filling?
 * <br>TODO: why does this do its own sub-dispatch?
 *
 * @param version  API version of the request (not read by this method)
 * @param route    matched route; only a route whose handler method is named "train" is accepted
 * @param parms    request parameters which override the builder defaults
 * @param postBody request body (not read by this method)
 * @return the grid search schema holding the started job
 * @throws Exception whatever the invoked schema, builder or grid search code throws
 */
@Override
public S handle(int version, water.api.Route route, Properties parms, String postBody) throws Exception {
  // This entry point serves the train route only.
  final String handlerMethodName = route._handler_method.getName();
  if (!"train".equals(handlerMethodName)) {
    throw H2O.unimpl();
  }

  // The algo is the 4th segment of the URL: {}/{99}/{Grid}/{gbm}/
  final String[] urlSegments = route._url.split("/");
  final String algoURLName = urlSegments[3];
  final String algoName = ModelBuilder.algoName(algoURLName); // gbm -> GBM; deeplearning -> DeepLearning
  final String schemaDir = ModelBuilder.schemaDirectory(algoURLName);

  // Hack replacing an older hack: the SchemaServer.schema*() lookups are being eliminated, so the
  // schema version is hard-coded here. TODO: unhack all of this, there probably should not be any
  // reference to the algo version at all.
  final boolean usesV99Schema =
      algoName.equals("SVD") || algoName.equals("Aggregator") || algoName.equals("StackedEnsemble");
  final int algoVersion = usesV99Schema ? 99 : 3;
  final String versionSuffix = "V" + algoVersion;

  // TODO: this name construction is a horrible hack which is going to cause maintenance problems.
  final String paramSchemaName =
      schemaDir + algoName + versionSuffix + "$" + ModelBuilder.paramName(algoURLName) + versionSuffix;

  // Create the grid search schema together with an empty parameters schema.
  final S gridSchema = (S) new GridSearchSchema();
  gridSchema.init_meta();
  gridSchema.parameters = (P) TypeMap.newFreezable(paramSchemaName);
  gridSchema.parameters.init_meta();
  gridSchema.hyper_parameters = new IcedHashMap<>();

  // Builder defaults go into the schema first, the user supplied values are laid over them.
  final ModelBuilder defaultsBuilder = ModelBuilder.make(algoURLName, null, null);
  gridSchema.parameters.fillFromImpl(defaultsBuilder._parms);
  gridSchema.fillFromParms(parms);

  // Check the hyper parameter names (names only, no types). The schema names validation_frame
  // and training_frame are still in use at this point - do not rename them earlier.
  validateHyperParams((P) gridSchema.parameters, gridSchema.hyper_parameters);

  // The actual model parameters.
  final MP baseParams = (MP) gridSchema.parameters.createAndFillImpl();

  // From here on the key has to be "valid" instead of "validation_frame": the hyper space walker
  // complains about parameter names it does not know.
  final Map<String, Object[]> hyperSpace = new TreeMap<>(gridSchema.hyper_parameters);
  if (hyperSpace.containsKey("validation_frame")) {
    hyperSpace.put("valid", hyperSpace.remove("validation_frame"));
  }

  // Key of the target grid, if the request named one.
  // FIXME: Grid ID is not passed to the grid search builder!
  final Key<Grid> gridKey;
  if (gridSchema.grid_id == null) {
    gridKey = null;
  } else {
    gridKey = gridSchema.grid_id.key();
  }

  // Start the grid search; the job is handed back to the caller inside the schema.
  final DefaultModelParametersBuilderFactory<MP, P> builderFactory =
      new DefaultModelParametersBuilderFactory<MP, P>();
  final HyperSpaceSearchCriteria searchCriteria =
      (HyperSpaceSearchCriteria) gridSchema.search_criteria.createAndFillImpl();
  final Job<Grid> searchJob =
      GridSearch.startGridSearch(gridKey, baseParams, hyperSpace, builderFactory, searchCriteria);

  // Fill the schema with the job information.
  // FIXME: right now the grid parameters have to be removed from the response.
  gridSchema.hyper_parameters = null;
  gridSchema.total_models = searchJob._result.get().getModelCount(); // TODO: looks like it's currently always 0
  gridSchema.job = new JobV3(searchJob);
  return gridSchema;
}
/**
 * Placeholder for the "train" endpoint method. The body always fails with {@code H2O.fail()};
 * {@code handle} in this class only compares the routed method's name against "train".
 *
 * @param version          API version of the request
 * @param gridSearchSchema grid search schema of the request
 * @return never returns normally
 */
@SuppressWarnings("unused") // called through reflection by RequestServer
public S train(int version, S gridSearchSchema) {
  throw H2O.fail();
}
/**
 * Validate given hyper parameters with respect to the given parameters schema.
 *
 * <p>Every hyper parameter name has to match the name of a field listed in the field metadata of
 * {@code params}, and the matching field has to be flagged as gridable. Only the names are
 * checked here; the hyper parameter values are not inspected.
 *
 * @param params regular model build parameters
 * @param hyperParams map of hyper parameters
 * @throws H2OIllegalArgumentException if a hyper parameter name is unknown or the parameter is
 *         not gridable
 */
protected void validateHyperParams(P params, Map<String, Object[]> hyperParams) {
  List<SchemaMetadata.FieldMetadata> fsMeta = SchemaMetadata.getFieldMetadata(params);
  for (String hyperParamName : hyperParams.keySet()) {
    // Find the metadata of the schema field carrying the same name
    SchemaMetadata.FieldMetadata fieldMetadata = null;
    for (SchemaMetadata.FieldMetadata fm : fsMeta) {
      if (fm.name.equals(hyperParamName)) {
        fieldMetadata = fm;
        break;
      }
    }
    if (fieldMetadata == null) {
      throw new H2OIllegalArgumentException(hyperParamName, "grid",
          "Unknown hyper parameter for grid search!");
    }
    if (!fieldMetadata.is_gridable) {
      // The parameter name is enclosed in a balanced pair of quotes
      throw new H2OIllegalArgumentException(hyperParamName, "grid",
          "Illegal hyper parameter for grid search! The parameter '"
          + fieldMetadata.name + "' is not gridable!");
    }
  }
}
/**
 * Factory handing out {@link ModelParametersFromSchemaBuilder} instances.
 *
 * @param <MP> type of model parameters
 * @param <PS> type of schema representing model parameters
 */
public static class DefaultModelParametersBuilderFactory<MP extends Model.Parameters, PS extends ModelParametersSchemaV3>
    implements ModelParametersBuilderFactory<MP> {

  /**
   * Creates a new schema based builder wrapping the given parameters.
   *
   * @param initialParams parameters object handed over to the created builder
   * @return a fresh builder instance
   */
  @Override
  public ModelParametersBuilder<MP> get(MP initialParams) {
    final ModelParametersFromSchemaBuilder<MP, PS> schemaBasedBuilder =
        new ModelParametersFromSchemaBuilder<>(initialParams);
    return schemaBasedBuilder;
  }

  /**
   * @return the naming strategy of this factory, always
   *         {@link PojoUtils.FieldNaming#DEST_HAS_UNDERSCORES}
   */
  @Override
  public PojoUtils.FieldNaming getFieldNamingStrategy() {
    final PojoUtils.FieldNaming namingStrategy = PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES;
    return namingStrategy;
  }
}
/**
 * Model parameters builder which fills a model parameters object through its schema.
 *
 * <p>A user calls the {@link #set(String, Object)} method with names of parameters as they are
 * defined in the schema. Each value is stored in the schema first; {@link #build()} then
 * transfers the recorded fields from the schema to the corresponding model parameters object.
 * </p>
 *
 * @param <MP> type of model parameters
 * @param <PS> type of schema representing model parameters
 */
public static class ModelParametersFromSchemaBuilder<MP extends Model.Parameters, PS extends ModelParametersSchemaV3>
    implements ModelParametersBuilderFactory.ModelParametersBuilder<MP> {

  /** Parameters object which is modified and returned by {@link #build()}. */
  private final MP params;
  /** Schema instance used as intermediate storage for the values passed to {@link #set}. */
  private final PS paramsSchema;
  /** Names of the schema fields which were set so far, in call order. */
  private final List<String> fields;

  /**
   * @param initialParams parameters to start from; the very same object is updated by
   *                      {@link #build()}
   */
  public ModelParametersFromSchemaBuilder(MP initialParams) {
    params = initialParams;
    paramsSchema = (PS) SchemaServer.schema(-1, params.getClass());
    fields = new ArrayList<>(7);
  }

  /**
   * Stores the given value in the schema field of the given name and records the name for
   * {@link #build()}.
   *
   * @param name  name of a public field of the parameters schema
   * @param value value to set; it is handed over in its {@code toString()} form
   * @return this builder
   * @throws IllegalArgumentException if the field does not exist, carries no {@code @API}
   *         annotation, or the value cannot be set
   */
  public ModelParametersFromSchemaBuilder<MP, PS> set(String name, Object value) {
    try {
      Field f = paramsSchema.getClass().getField(name);
      // Look the annotation up by type: relying on it being the first annotation of the field
      // breaks as soon as the field carries any other annotation.
      API api = f.getAnnotation(API.class);
      if (api == null) {
        // Wrapped into an IllegalArgumentException by the RuntimeException handler below
        throw new IllegalStateException("Schema field '" + name + "' is not annotated with @API");
      }
      Schema.setField(paramsSchema, f, name, value.toString(), api.required(),
          paramsSchema.getClass());
      fields.add(name);
    } catch (NoSuchFieldException e) {
      throw new IllegalArgumentException("Cannot find field '" + name + "'" + " to value " + value, e);
    } catch (IllegalAccessException e) {
      throw new IllegalArgumentException("Cannot set field '" + name + "'" + " to value " + value, e);
    } catch (RuntimeException e) {
      throw new IllegalArgumentException("Cannot set field '" + name + "'" + " to value " + value, e);
    }
    return this;
  }

  /**
   * Transfers the values recorded by {@link #set(String, Object)} from the schema into the
   * parameters object and returns it.
   *
   * @return the parameters object passed to the constructor, updated in place
   */
  public MP build() {
    // The recorded field names are passed as the last argument
    // NOTE(review): presumably this limits the copy to exactly these fields - confirm in PojoUtils
    PojoUtils
        .copyProperties(params, paramsSchema, PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES, null,
            fields.toArray(new String[fields.size()]));
    // FIXME: handle these train/valid fields in different way
    // See: ModelParametersSchemaV3#fillImpl
    // Take the frame keys from the schema if they are still missing in the parameters
    if (params._valid == null && paramsSchema.validation_frame != null) {
      params._valid = Key.make(paramsSchema.validation_frame.name);
    }
    if (params._train == null && paramsSchema.training_frame != null) {
      params._train = Key.make(paramsSchema.training_frame.name);
    }
    return params;
  }
}
}