package water.api;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelMojoWriter;
import hex.schemas.ModelBuilderSchema;
import water.H2O;
import water.Iced;
import water.api.schemas3.ModelBuildersV3;
import water.api.schemas3.SchemaV3;
import water.api.schemas4.ListRequestV4;
import water.api.schemas4.ModelInfoV4;
import water.api.schemas4.ModelsInfoV4;
import water.util.ReflectionUtils;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Locale;
/**
 * REST handler exposing the registered model builders (algorithms).
 *
 * <p>The public endpoint methods are invoked reflectively by {@code RequestServer}, so their
 * signatures must stay unchanged even though nothing in this class calls them directly.
 */
class ModelBuildersHandler extends Handler {

  /**
   * Returns the schemas of all registered model builders, keyed by lower-cased algorithm name.
   *
   * @param version schema version requested by the client
   * @param m request schema; its {@code model_builders} field is replaced with the result
   * @return {@code m}, populated with one entry per algorithm from {@link ModelBuilder#algos()}
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelBuildersV3 list(int version, ModelBuildersV3 m) {
    m.model_builders = new ModelBuilderSchema.IcedHashMapStringModelBuilderSchema();
    for (String algo : ModelBuilder.algos()) {
      putBuilderSchema(version, m, algo);
    }
    return m;
  }

  /**
   * Returns the schema of the single model builder named by {@code m.algo}.
   *
   * @param version schema version requested by the client
   * @param m request schema; {@code m.algo} selects the builder and {@code m.model_builders} is
   *     replaced with a map holding just that builder's schema
   * @return {@code m}, populated
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelBuildersV3 fetch(int version, ModelBuildersV3 m) {
    m.model_builders = new ModelBuilderSchema.IcedHashMapStringModelBuilderSchema();
    putBuilderSchema(version, m, m.algo);
    return m;
  }

  /**
   * Instantiates the builder for {@code algo}, renders it with the schema matching
   * {@code version}, and stores the result in {@code m.model_builders}.
   */
  private static void putBuilderSchema(int version, ModelBuildersV3 m, String algo) {
    ModelBuilder builder = ModelBuilder.make(algo, null, null);
    ModelBuilderSchema schema =
        (ModelBuilderSchema) SchemaServer.schema(version, builder).fillFromImpl(builder);
    // Locale.ROOT keeps the key independent of the server's default locale: in e.g. a Turkish
    // locale, "IsolationForest".toLowerCase() would start with a dotless 'ı'.
    m.model_builders.put(algo.toLowerCase(Locale.ROOT), schema);
  }

  /** Output schema carrying a freshly generated model id. */
  public static class ModelIdV3 extends SchemaV3<Iced, ModelIdV3> {
    @API(help="Model ID", direction = API.Direction.OUTPUT)
    String model_id;
  }

  /**
   * Calculates the next unique model id for the algorithm named by {@code m.algo}.
   *
   * @param version schema version requested by the client (unused)
   * @param m request schema; only {@code m.algo} is read
   * @return a new schema holding the generated id
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelIdV3 calcModelId(int version, ModelBuildersV3 m) {
    String modelId = H2O.calcNextUniqueModelId(m.algo);
    ModelIdV3 result = new ModelIdV3();
    result.model_id = modelId;
    return result;
  }

  /**
   * Describes every registered algorithm: maturity level, MOJO/POJO support, and MOJO version.
   *
   * @param version schema version requested by the client (unused)
   * @param m list request (unused)
   * @return one {@link ModelInfoV4} per algorithm, in {@link ModelBuilder#algos()} order
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelsInfoV4 modelsInfo(int version, ListRequestV4 m) {
    String[] algos = ModelBuilder.algos();
    ModelInfoV4[] infos = new ModelInfoV4[algos.length];
    ModelsInfoV4 res = new ModelsInfoV4();
    for (int i = 0; i < algos.length; i++) {
      ModelBuilder builder = ModelBuilder.make(algos[i], null, null);
      ModelInfoV4 info = new ModelInfoV4();
      info.algo = algos[i];
      info.maturity = maturityOf(builder.builderVisibility());
      info.have_mojo = builder.haveMojo();
      info.have_pojo = builder.havePojo();
      info.mojo_version = info.have_mojo ? detectMojoVersion(builder) : null;
      infos[i] = info;
    }
    res.models = infos;
    return res;
  }

  /** Maps a builder visibility to its maturity label; anything other than Stable/Beta is "alpha". */
  private static String maturityOf(ModelBuilder.BuilderVisibility visibility) {
    if (visibility == ModelBuilder.BuilderVisibility.Stable) {
      return "stable";
    }
    if (visibility == ModelBuilder.BuilderVisibility.Beta) {
      return "beta";
    }
    return "alpha";
  }

  /**
   * Determines the MOJO version produced for the builder's model class.
   *
   * <p>The concrete MOJO writer type is read from the declared return type of the model class's
   * own {@code getMojo()} method. A throw-away writer instance is then created to query its
   * version.
   *
   * @throws RuntimeException if the model class does not declare {@code getMojo()}, declares it
   *     with a non-concrete writer return type, or the writer cannot be instantiated
   */
  private String detectMojoVersion(ModelBuilder builder) {
    Class<? extends Model> modelClass = ReflectionUtils.findActualClassParameter(builder.getClass(), 0);
    Class<?> writerClass = mojoWriterClass(modelClass);
    return newMojoWriter(writerClass).mojoVersion();
  }

  /** Returns the concrete {@link ModelMojoWriter} subclass declared as {@code getMojo()}'s return type. */
  private static Class<?> mojoWriterClass(Class<? extends Model> modelClass) {
    final Method getMojoMethod;
    try {
      // getDeclaredMethod: the model class itself must declare getMojo(); inherited ones don't count.
      getMojoMethod = modelClass.getDeclaredMethod("getMojo");
    } catch (NoSuchMethodException e) {
      throw new RuntimeException("Model class " + modelClass + " is expected to have method getMojo();", e);
    }
    Class<?> retClass = getMojoMethod.getReturnType();
    if (retClass == ModelMojoWriter.class || !ModelMojoWriter.class.isAssignableFrom(retClass)) {
      throw new RuntimeException("Method getMojo() in " + modelClass + " must return the concrete implementation " +
          "of the ModelMojoWriter class. The return type is declared as " + retClass);
    }
    return retClass;
  }

  /** Instantiates {@code writerClass} through its no-arg constructor. */
  private static ModelMojoWriter newMojoWriter(Class<?> writerClass) {
    try {
      // Constructor.newInstance replaces the deprecated Class.newInstance, which could
      // propagate undeclared checked exceptions from the constructor.
      return (ModelMojoWriter) writerClass.getDeclaredConstructor().newInstance();
    } catch (NoSuchMethodException | InstantiationException | IllegalAccessException e) {
      throw getMissingCtorException(writerClass, e);
    } catch (InvocationTargetException e) {
      throw new RuntimeException("MojoWriter class " + writerClass + " failed in its no-arg constructor", e.getCause());
    }
  }

  /** Builds the error reported when a MOJO writer cannot be created via a no-arg constructor. */
  private static RuntimeException getMissingCtorException(Class<?> retClass, Exception e) {
    return new RuntimeException("MojoWriter class " + retClass + " must define a no-arg constructor.\n" + e, e);
  }
}