package hex.schemas;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import water.AutoBuffer;
import water.H2O;
import water.api.API;
import water.api.SpecifiesHttpResponseCode;
import water.api.schemas3.*;
import water.util.IcedSortedHashMap;
import water.util.ReflectionUtils;

import java.util.Arrays;
import java.util.Locale;
import java.util.Properties;
/**
 * Base REST schema for {@link ModelBuilder}s.
 *
 * <p>Input: the model-specific {@link #parameters} schema. Output: descriptive metadata about the
 * builder (algo names, buildable model categories, supervision flag, visibility), the builder's job
 * (if any), and the parameter validation messages together with the validation error count.
 *
 * @param <B> the ModelBuilder type this schema maps to
 * @param <S> the concrete schema type (returned by the fluent fill methods)
 * @param <P> the model-specific parameters schema type
 */
public class ModelBuilderSchema<B extends ModelBuilder, S extends ModelBuilderSchema<B,S,P>, P extends ModelParametersSchemaV3>
    extends RequestSchemaV3<B,S> implements SpecifiesHttpResponseCode {

  // NOTE: currently ModelBuilderSchema has its own JSON serializer.
  // If you add more fields here you MUST add them to writeJSON_impl() below.

  /** Named, concrete subclass of {@code IcedSortedHashMap<String, ModelBuilderSchema>}. */
  public static class IcedHashMapStringModelBuilderSchema extends IcedSortedHashMap<String, ModelBuilderSchema> {}

  // Input fields

  @API(help="Model builder parameters.")
  public P parameters;

  // Output fields

  @API(help="The algo name for this ModelBuilder.", direction=API.Direction.OUTPUT)
  public String algo;

  @API(help="The pretty algo name for this ModelBuilder (e.g., Generalized Linear Model, rather than GLM).", direction=API.Direction.OUTPUT)
  public String algo_full_name;

  @API(help="Model categories this ModelBuilder can build.", values={ "Unknown", "Binomial", "Multinomial", "Regression", "Clustering", "AutoEncoder", "DimReduction" }, direction = API.Direction.OUTPUT)
  public ModelCategory[] can_build;

  @API(help="Indicator whether the model is supervised or not.", direction=API.Direction.OUTPUT)
  public boolean supervised;

  @API(help="Should the builder always be visible, be marked as beta, or only visible if the user starts up with the experimental flag?", values = { "Experimental", "Beta", "AlwaysVisible" }, direction = API.Direction.OUTPUT)
  public ModelBuilder.BuilderVisibility visibility;

  @API(help = "Job Key", direction = API.Direction.OUTPUT)
  public JobV3 job;

  @API(help="Parameter validation messages", direction=API.Direction.OUTPUT)
  public ValidationMessageV3[] messages;

  @API(help="Count of parameter validation errors", direction=API.Direction.OUTPUT)
  public int error_count;

  @API(help="HTTP status to return for this build.", json = false, direction=API.Direction.OUTPUT)
  public int __http_status; // The handler sets this to 400 if we're building and error_count > 0, else 200.

  /** Creates a schema whose {@link #parameters} is a freshly instantiated parameters schema. */
  public ModelBuilderSchema() {
    this.parameters = createParametersSchema();
  }

  /**
   * Sets the HTTP status code to report for this request.
   *
   * @param status the HTTP status code
   */
  public void setHttpStatus(int status) {
    __http_status = status;
  }

  /** @return the HTTP status code to report for this request */
  public int httpStatus() {
    return __http_status;
  }

  /**
   * Factory method to create the model-specific parameters schema.
   *
   * <p>For the base class itself a plain {@link ModelParametersSchemaV3} is returned. For subclasses the
   * concrete class is resolved reflectively from the subclass's type argument at index 2 (presumably
   * the {@code P} parameter) and instantiated via its no-arg constructor.
   *
   * @return a new, unfilled parameters schema
   * @throws RuntimeException (via {@code H2O.fail}) if the class cannot be resolved or instantiated
   */
  @SuppressWarnings("unchecked")
  final public P createParametersSchema() {
    // special case, because ModelBuilderSchema is the top of the tree and is parameterized differently
    if (ModelBuilderSchema.class == this.getClass()) {
      return (P) new ModelParametersSchemaV3();
    }
    try {
      Class<? extends ModelParametersSchemaV3> parameters_class = ReflectionUtils.findActualClassParameter(this.getClass(), 2);
      return (P) parameters_class.newInstance();
    }
    catch (Exception e) {
      throw H2O.fail("Caught exception trying to instantiate a builder instance for ModelBuilderSchema: " + this + ": " + e, e);
    }
  }

  /**
   * Fills the {@link #parameters} schema from the given request properties.
   *
   * @param parms request parameters
   * @return this schema
   */
  @SuppressWarnings("unchecked")
  public S fillFromParms(Properties parms) {
    this.parameters.fillFromParms(parms);
    return (S) this;
  }

  /** Create the corresponding impl object, as well as its parameters object. */
  @Override final public B createImpl() {
    return ModelBuilder.make(getSchemaType(), null, null);
  }

  /**
   * Copies this schema's state into {@code impl}, copies {@link #parameters} into {@code impl._parms},
   * and then calls {@code impl.init(false)} to validate the parameters.
   *
   * @param impl the builder to fill
   * @return {@code impl}
   */
  @Override public B fillImpl(B impl) {
    super.fillImpl(impl);
    parameters.fillImpl(impl._parms);
    impl.init(false); // validate parameters
    return impl;
  }

  /**
   * Generic filling from the impl: copies builder metadata, job, validation messages, error count and
   * parameters out of {@code builder}.
   *
   * <p>Does NOT call {@code builder.init()}, since the builder may already be running.
   *
   * @param builder the builder to describe
   * @return this schema
   */
  @SuppressWarnings("unchecked")
  @Override public S fillFromImpl(B builder) {
    // DO NOT, because it can already be running: builder.init(false); // check params
    // Locale.ROOT: locale-independent key (e.g. Turkish locale would turn 'I' into a dotless 'ı').
    this.algo = builder._parms.algoName().toLowerCase(Locale.ROOT);
    this.algo_full_name = builder._parms.fullName();
    this.supervised = builder.isSupervised();
    this.can_build = builder.can_build();
    this.visibility = builder.builderVisibility();
    job = builder._job == null ? null : new JobV3(builder._job);

    // In general, you can ask about a builder in-progress, and the error
    // message list can be growing - so you have to be prepared to read it
    // racily. Common for Grid searches exploring with broken parameter
    // choices.
    final ModelBuilder.ValidationMessage[] msgs = builder._messages; // Racily growing; read only once
    if( msgs != null ) {
      ValidationMessageV3[] converted = new ValidationMessageV3[msgs.length];
      int count = 0;
      for (ModelBuilder.ValidationMessage vm : msgs) {
        if( vm != null ) converted[count++] = new ValidationMessageV3().fillFromImpl(vm); // TODO: version // Note: does default field_name mapping
      }
      // Null entries were skipped above; drop the unused tail so no null elements are exposed.
      if (count < converted.length) converted = Arrays.copyOf(converted, count);
      this.messages = converted;
      // default fieldname hacks
      ValidationMessageV3.mapValidationMessageFieldNames(this.messages, new String[]{"_train", "_valid"}, new
          String[]{"training_frame", "validation_frame"});
    }
    this.error_count = builder.error_count();

    parameters = createParametersSchema();
    parameters.fillFromImpl(builder._parms);
    parameters.model_id = builder.dest() == null ? null : new KeyV3.ModelKeyV3(builder.dest());
    return (S) this;
  }

  // TODO: Drop this writeJSON_impl and use the default one.
  // TODO: Pull out the help text & metadata into the ParameterSchema for the front-end to display.
  /**
   * Custom JSON serializer: writes each output field as a comma-separated JSON member, followed by the
   * parameters block. Enclosing braces are not written here (presumably added by the caller — confirm).
   *
   * @param ab the buffer to write into
   * @return {@code ab}
   */
  public final AutoBuffer writeJSON_impl( AutoBuffer ab ) {
    ab.putJSON("job", job);
    ab.put1(',');
    ab.putJSONStr("algo", algo);
    ab.put1(',');
    ab.putJSONStr("algo_full_name", algo_full_name);
    ab.put1(',');
    ab.putJSONAEnum("can_build", can_build);
    ab.put1(',');
    ab.putJSONEnum("visibility", visibility);
    ab.put1(',');
    ab.putJSONZ("supervised", supervised);
    ab.put1(',');
    ab.putJSONA("messages", messages);
    ab.put1(',');
    ab.putJSON4("error_count", error_count);
    ab.put1(',');

    // The second schema is filled from a freshly created Model.Parameters impl, i.e. (presumably) the
    // default values to report alongside the current ones.
    ModelParametersSchemaV3.writeParametersJSON(ab, parameters, createParametersSchema().fillFromImpl((Model.Parameters)parameters.createImpl()));
    return ab;
  }
}