package water.api;
import dontweave.gson.Gson;
import dontweave.gson.GsonBuilder;
import dontweave.gson.JsonElement;
import dontweave.gson.JsonObject;
import hex.VarImp;
import hex.deeplearning.DeepLearning;
import hex.drf.DRF;
import hex.gbm.GBM;
import hex.glm.GLM2;
import hex.glm.GLMModel;
import hex.singlenoderf.SpeeDRF;
import hex.nb.NaiveBayes;
import hex.nb.NBModel;
import org.apache.commons.math3.util.Pair;
import water.*;
import water.api.Frames.FrameSummary;
import water.fvec.Frame;
import java.util.*;
import static water.util.ParamUtils.*;
public class Models extends Request2 {
///////////////////////
// Request2 boilerplate
///////////////////////
static final int API_WEAVER=1; // This file has auto-gen'd doc & json fields
static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.
// This Request supports the HTML 'GET' command, and this is the help text
// for GET.
static final String DOC_GET = "Return the list of models.";
public static String link(Key k, String content){
return "<a href='/2/Models'>" + content + "</a>";
}
////////////////
// Query params:
////////////////
@API(help="An existing H2O Model key.", required=false, filter=Default.class)
Model key = null;
@API(help="Find Frames that are compatible with the Model.", required=false, filter=Default.class)
boolean find_compatible_frames = false;
@API(help="An existing H2O Frame key to score with the Model which is specified by the key parameter.", required=false, filter=Default.class)
Frame score_frame = null;
@API(help="Should we adapt() the Frame to the Model?", required=false, filter=Default.class)
boolean adapt = true;
/////////////////
// The Code (tm):
/////////////////
public static final Gson gson = new GsonBuilder().serializeSpecialFloatingPointValues().setPrettyPrinting().create();
public static final class ModelSummary {
public String[] warnings = new String[0];
public String model_algorithm = "unknown";
public Model.ModelCategory model_category = Model.ModelCategory.Unknown;
public Job.JobState state = Job.JobState.CREATED;
public String id = null;
public String key = null;
public long creation_epoch_time_millis = -1;
public long training_duration_in_ms = -1;
public List<String> input_column_names = new ArrayList<String>();
public String response_column_name = "unknown";
public Map critical_parameters = new HashMap<String, Object>();
public Map secondary_parameters = new HashMap<String, Object>();
public Map expert_parameters = new HashMap<String, Object>();
public Map variable_importances = null;
public Set<String> compatible_frames = new HashSet<String>();
}
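// For orientation, a ModelSummary serialized with the gson instance above looks roughly like
// (field values are illustrative only, not taken from any real model):
//   { "warnings": [], "model_algorithm": "GBM", "model_category": "Binomial", "state": "DONE",
//     "key": "gbm_model_1", "critical_parameters": {"ntrees": 50.0, ...},
//     "compatible_frames": ["frame_1"] }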
private static Map whitelistJsonObject(JsonObject unfiltered, Set<String> whitelist) {
// If we create a new JsonObject here and serialize it, the key/value pairs end up inside
// a superfluous "members" object, so create a Map instead.
JsonObject filtered = new JsonObject();
Set<Map.Entry<String,JsonElement>> entries = unfiltered.entrySet();
for (Map.Entry<String,JsonElement> entry : entries) {
String key = entry.getKey();
if (whitelist.contains(key))
filtered.add(key, entry.getValue());
}
return gson.fromJson(gson.toJson(filtered), Map.class);
}
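// Usage sketch (hypothetical parameter names, for illustration only):
//   JsonObject all_params = new JsonObject();
//   all_params.addProperty("ntrees", 50);
//   all_params.addProperty("seed", 42);
//   Map filtered = whitelistJsonObject(all_params, new HashSet<String>(Arrays.asList("ntrees")));
//   // filtered now contains only {ntrees=50.0}; Gson deserializes JSON numbers as Double.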
/**
* Fetch all the Frames so we can see if they are compatible with our Model(s).
*/
private Pair<Map<String, Frame>, Map<String, Set<String>>> fetchFrames() {
Map<String, Frame> all_frames = null;
Map<String, Set<String>> all_frames_cols = null;
if (this.find_compatible_frames) {
// caches for this request
all_frames = Frames.fetchAll();
all_frames_cols = new TreeMap<String, Set<String>>();
for (Map.Entry<String, Frame> entry : all_frames.entrySet()) {
all_frames_cols.put(entry.getKey(), new TreeSet<String>(Arrays.asList(entry.getValue()._names)));
}
}
return new Pair<Map<String, Frame>, Map<String, Set<String>>>(all_frames, all_frames_cols);
}
private static Map<String, Frame> findCompatibleFrames(Model model, Map<String, Frame> all_frames, Map<String, Set<String>> all_frames_cols) {
Map<String, Frame> compatible_frames = new TreeMap<String, Frame>();
Set<String> model_column_names = new HashSet<String>(Arrays.asList(model._names));
for (Map.Entry<String, Set<String>> entry : all_frames_cols.entrySet()) {
Set<String> frame_cols = entry.getValue();
if (frame_cols.containsAll(model_column_names)) {
// See if adapt() throws an exception or not.
try {
Frame frame = all_frames.get(entry.getKey());
Frame[] outputs = model.adapt(frame, false); // TODO: this does too much work; write canAdapt()
Frame adapted = outputs[0];
Frame trash = outputs[1];
// adapted.delete(); // TODO: shouldn't we clean up adapted vecs? But we can't delete() the frame as a whole. . .
trash.delete();
// A-Ok
compatible_frames.put(entry.getKey(), frame);
}
catch (Exception e) {
// skip
}
}
}
return compatible_frames;
}
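// A minimal sketch of the canAdapt() helper mentioned in the TODO above. It assumes that
// Model.adapt(Frame, boolean) throws on an incompatible Frame and returns {adapted, trash},
// exactly as used in findCompatibleFrames(); it is not part of the current Model API.
@SuppressWarnings("unused")
private static boolean canAdapt(Model model, Frame frame) {
try {
Frame[] outputs = model.adapt(frame, false);
outputs[1].delete(); // clean up the throwaway vecs created by adapt()
return true;
}
catch (Exception e) {
return false;
}
}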
public static Map<String, ModelSummary> generateModelSummaries(Set<String> keys, Map<String, Model> models, boolean find_compatible_frames, Map<String, Frame> all_frames, Map<String, Set<String>> all_frames_cols) {
Map<String, ModelSummary> modelSummaries = new TreeMap<String, ModelSummary>();
if (null == keys) {
keys = models.keySet();
}
for (String key : keys) {
ModelSummary summary = new ModelSummary();
Models.summarizeAndEnhanceModel(summary, models.get(key), find_compatible_frames, all_frames, all_frames_cols);
modelSummaries.put(key, summary);
}
return modelSummaries;
}
/**
* Summarize subclasses of water.Model.
*/
protected static void summarizeAndEnhanceModel(ModelSummary summary, Model model, boolean find_compatible_frames, Map<String, Frame> all_frames, Map<String, Set<String>> all_frames_cols) {
if (model instanceof GLMModel) {
summarizeGLMModel(summary, (GLMModel) model);
} else if (model instanceof DRF.DRFModel) {
summarizeDRFModel(summary, (DRF.DRFModel) model);
} else if (model instanceof hex.deeplearning.DeepLearningModel) {
summarizeDeepLearningModel(summary, (hex.deeplearning.DeepLearningModel) model);
} else if (model instanceof hex.gbm.GBM.GBMModel) {
summarizeGBMModel(summary, (hex.gbm.GBM.GBMModel) model);
} else if (model instanceof hex.singlenoderf.SpeeDRFModel) {
summarizeSpeeDRFModel(summary, (hex.singlenoderf.SpeeDRFModel) model);
} else if (model instanceof NBModel) {
summarizeNBModel(summary, (NBModel) model);
} else {
// catch-all
summarizeModelCommonFields(summary, model);
}
if (find_compatible_frames) {
Map<String, Frame> compatible_frames = findCompatibleFrames(model, all_frames, all_frames_cols);
summary.compatible_frames = compatible_frames.keySet();
}
}
/**
* Summarize fields which are generic to water.Model.
*/
private static void summarizeModelCommonFields(ModelSummary summary, Model model) {
String[] names = model._names;
summary.warnings = model.warnings;
summary.model_algorithm = model.getClass().toString(); // fallback only
// model.job() is a local copy; on multi-node clusters we need to fetch the current Job state from the DKV.
Key job_key = ((Job)model.job()).self();
if (null == job_key) throw H2O.fail("Null job key for model: " + model._key); // later, when we deserialize models from disk, we'll relax this constraint
Job job = DKV.get(job_key).get();
summary.state = job.getState();
summary.model_category = model.getModelCategory();
UniqueId unique_id = model.getUniqueId();
summary.id = unique_id.getId();
summary.key = unique_id.getKey();
summary.creation_epoch_time_millis = unique_id.getCreationEpochTimeMillis();
summary.training_duration_in_ms = model.training_duration_in_ms;
summary.response_column_name = names[names.length - 1];
for (int i = 0; i < names.length - 1; i++)
summary.input_column_names.add(names[i]);
// Flatten the variable importances into a map for JSON serialization.
VarImp vi = model.varimp();
if (null != vi) {
summary.variable_importances = new LinkedHashMap<String, Object>();
summary.variable_importances.put("varimp", vi.varimp);
summary.variable_importances.put("variables", vi.getVariables());
summary.variable_importances.put("method", vi.method);
summary.variable_importances.put("max_var", vi.max_var);
summary.variable_importances.put("scaled", vi.scaled());
}
}
/******
* GLM2
******/
private static final Set<String> GLM_critical_params = getCriticalParamNames(GLM2.DOC_FIELDS);
private static final Set<String> GLM_secondary_params = getSecondaryParamNames(GLM2.DOC_FIELDS);
private static final Set<String> GLM_expert_params = getExpertParamNames(GLM2.DOC_FIELDS);
/**
* Summarize fields which are specific to hex.glm.GLMModel.
*/
private static void summarizeGLMModel(ModelSummary summary, hex.glm.GLMModel model) {
// add generic fields such as column names
summarizeModelCommonFields(summary, model);
summary.model_algorithm = "GLM";
JsonObject all_params = (model.get_params()).toJSON();
summary.critical_parameters = whitelistJsonObject(all_params, GLM_critical_params);
summary.secondary_parameters = whitelistJsonObject(all_params, GLM_secondary_params);
summary.expert_parameters = whitelistJsonObject(all_params, GLM_expert_params);
}
/******
* DRF
******/
private static final Set<String> DRF_critical_params = getCriticalParamNames(DRF.DOC_FIELDS);
private static final Set<String> DRF_secondary_params = getSecondaryParamNames(DRF.DOC_FIELDS);
private static final Set<String> DRF_expert_params = getExpertParamNames(DRF.DOC_FIELDS);
/**
* Summarize fields which are specific to hex.drf.DRF.DRFModel.
*/
private static void summarizeDRFModel(ModelSummary summary, hex.drf.DRF.DRFModel model) {
// add generic fields such as column names
summarizeModelCommonFields(summary, model);
summary.model_algorithm = "BigData RF";
JsonObject all_params = (model.get_params()).toJSON();
summary.critical_parameters = whitelistJsonObject(all_params, DRF_critical_params);
summary.secondary_parameters = whitelistJsonObject(all_params, DRF_secondary_params);
summary.expert_parameters = whitelistJsonObject(all_params, DRF_expert_params);
}
/******
* SpeeDRF
******/
private static final Set<String> SpeeDRF_critical_params = getCriticalParamNames(SpeeDRF.DOC_FIELDS);
private static final Set<String> SpeeDRF_secondary_params = getSecondaryParamNames(SpeeDRF.DOC_FIELDS);
private static final Set<String> SpeeDRF_expert_params = getExpertParamNames(SpeeDRF.DOC_FIELDS);
/**
* Summarize fields which are specific to hex.singlenoderf.SpeeDRFModel.
*/
private static void summarizeSpeeDRFModel(ModelSummary summary, hex.singlenoderf.SpeeDRFModel model) {
// add generic fields such as column names
summarizeModelCommonFields(summary, model);
summary.model_algorithm = "Random Forest";
JsonObject all_params = (model.get_params()).toJSON();
summary.critical_parameters = whitelistJsonObject(all_params, SpeeDRF_critical_params);
summary.secondary_parameters = whitelistJsonObject(all_params, SpeeDRF_secondary_params);
summary.expert_parameters = whitelistJsonObject(all_params, SpeeDRF_expert_params);
}
/***************
* DeepLearning
***************/
private static final Set<String> DL_critical_params = getCriticalParamNames(DeepLearning.DOC_FIELDS);
private static final Set<String> DL_secondary_params = getSecondaryParamNames(DeepLearning.DOC_FIELDS);
private static final Set<String> DL_expert_params = getExpertParamNames(DeepLearning.DOC_FIELDS);
/**
* Summarize fields which are specific to hex.deeplearning.DeepLearningModel.
*/
private static void summarizeDeepLearningModel(ModelSummary summary, hex.deeplearning.DeepLearningModel model) {
// add generic fields such as column names
summarizeModelCommonFields(summary, model);
summary.model_algorithm = "DeepLearning";
JsonObject all_params = (model.get_params()).toJSON();
summary.critical_parameters = whitelistJsonObject(all_params, DL_critical_params);
summary.secondary_parameters = whitelistJsonObject(all_params, DL_secondary_params);
summary.expert_parameters = whitelistJsonObject(all_params, DL_expert_params);
}
/******
* GBM
******/
private static final Set<String> GBM_critical_params = getCriticalParamNames(GBM.DOC_FIELDS);
private static final Set<String> GBM_secondary_params = getSecondaryParamNames(GBM.DOC_FIELDS);
private static final Set<String> GBM_expert_params = getExpertParamNames(GBM.DOC_FIELDS);
/**
* Summarize fields which are specific to hex.gbm.GBM.GBMModel.
*/
private static void summarizeGBMModel(ModelSummary summary, hex.gbm.GBM.GBMModel model) {
// add generic fields such as column names
summarizeModelCommonFields(summary, model);
summary.model_algorithm = "GBM";
JsonObject all_params = (model.get_params()).toJSON();
summary.critical_parameters = whitelistJsonObject(all_params, GBM_critical_params);
summary.secondary_parameters = whitelistJsonObject(all_params, GBM_secondary_params);
summary.expert_parameters = whitelistJsonObject(all_params, GBM_expert_params);
}
/******
* NB
******/
private static final Set<String> NB_critical_params = getCriticalParamNames(NaiveBayes.DOC_FIELDS);
private static final Set<String> NB_secondary_params = getSecondaryParamNames(NaiveBayes.DOC_FIELDS);
private static final Set<String> NB_expert_params = getExpertParamNames(NaiveBayes.DOC_FIELDS);
/**
* Summarize fields which are specific to hex.nb.NBModel.
*/
private static void summarizeNBModel(ModelSummary summary, hex.nb.NBModel model) {
// add generic fields such as column names
summarizeModelCommonFields(summary, model);
summary.model_algorithm = "Naive Bayes";
JsonObject all_params = (model.get_params()).toJSON();
summary.critical_parameters = whitelistJsonObject(all_params, NB_critical_params);
summary.secondary_parameters = whitelistJsonObject(all_params, NB_secondary_params);
summary.expert_parameters = whitelistJsonObject(all_params, NB_expert_params);
}
/**
* Fetch all Models from the KV store.
*/
protected Map<String, Model> fetchAll() {
return H2O.KeySnapshot.globalSnapshot().fetchAll(water.Model.class);
}
/**
* Score a frame with the given model.
*/
protected static Response scoreOne(Frame frame, Model score_model, boolean adapt) {
// NOTE: the adapt flag is not currently used here; scoring is delegated to Frames.scoreOne().
return Frames.scoreOne(frame, score_model);
}
/**
* Fetch all the Models from the KV store, summarize and enhance them, and return a map of them.
*/
private Response serveOneOrAll(Map<String, Model> modelsMap) {
// returns empty sets if !this.find_compatible_frames
Pair<Map<String, Frame>, Map<String, Set<String>>> frames_info = fetchFrames();
Map<String, Frame> all_frames = frames_info.getFirst();
Map<String, Set<String>> all_frames_cols = frames_info.getSecond();
Map<String, ModelSummary> modelSummaries = Models.generateModelSummaries(null, modelsMap, find_compatible_frames, all_frames, all_frames_cols);
Map<String, Object> resultsMap = new LinkedHashMap<String, Object>();
resultsMap.put("models", modelSummaries);
// If find_compatible_frames then include a map of the Frame summaries. Should we put this on a separate switch?
if (this.find_compatible_frames) {
Set<String> all_referenced_frames = new TreeSet<String>();
for (Map.Entry<String, ModelSummary> entry: modelSummaries.entrySet()) {
ModelSummary summary = entry.getValue();
all_referenced_frames.addAll(summary.compatible_frames);
}
Map<String, FrameSummary> frameSummaries = Frames.generateFrameSummaries(all_referenced_frames, all_frames, false, null, null);
resultsMap.put("frames", frameSummaries);
}
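// At this point resultsMap has the shape {"models": {key: ModelSummary, ...}} plus, when
// find_compatible_frames is set, a "frames" map of {key: FrameSummary, ...}.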
// TODO: temporary hack to get things going
String json = gson.toJson(resultsMap);
JsonObject result = gson.fromJson(json, JsonElement.class).getAsJsonObject();
return Response.done(result);
}
@Override
protected Response serve() {
if (null == this.key) {
return serveOneOrAll(fetchAll());
} else {
if (null == this.score_frame) {
Model model = this.key;
Map<String, Model> modelsMap = new TreeMap<String, Model>(); // Sort for pretty display and reliable ordering.
modelsMap.put(model._key.toString(), model);
return serveOneOrAll(modelsMap);
} else {
return scoreOne(this.score_frame, this.key, this.adapt);
}
}
} // serve()
}