package water.api; import java.io.*; import java.net.URI; import java.util.*; import hex.Model; import hex.ModelMojoWriter; import hex.PartialDependence; import hex.genmodel.MojoModel; import water.*; import water.api.FramesHandler.Frames; import water.api.schemas3.*; import water.exceptions.*; import water.fvec.Frame; import water.persist.Persist; import water.util.FileUtils; import water.util.JCodeGen; public class ModelsHandler<I extends ModelsHandler.Models, S extends SchemaV3<I,S>> extends Handler { /** Class which contains the internal representation of the models list and params. */ public static final class Models extends Iced { public Key model_id; public Model[] models; public boolean find_compatible_frames = false; public static Model[] fetchAll() { final Key[] modelKeys = KeySnapshot.globalSnapshot().filter(new KeySnapshot.KVFilter() { @Override public boolean filter(KeySnapshot.KeyInfo k) { return Value.isSubclassOf(k._type, Model.class); } }).keys(); Model[] models = new Model[modelKeys.length]; for (int i = 0; i < modelKeys.length; i++) { Model model = getFromDKV("(none)", modelKeys[i]); models[i] = model; } return models; } /** * Fetch all the Frames so we can see if they are compatible with our Model(s). */ protected Map<Frame, Set<String>> fetchFrameCols() { if (!find_compatible_frames) return null; // caches for this request Frame[] all_frames = Frames.fetchAll(); Map<Frame, Set<String>> all_frames_cols = new HashMap<>(); for (Frame f : all_frames) all_frames_cols.put(f, new HashSet<>(Arrays.asList(f._names))); return all_frames_cols; } /** * For a given model return an array of the compatible frames. * * @param model The model to fetch the compatible frames for. * @param all_frames An array of all the Frames in the DKV. * @param all_frames_cols A Map of Frame to a Set of its column names. * @return */ private static Frame[] findCompatibleFrames(Model model, Frame[] all_frames, Map<Frame, Set<String>> all_frames_cols) { List<Frame> compatible_frames = new ArrayList<Frame>(); Set<String> model_column_names = new HashSet(Arrays.asList(model._output._names)); for (Map.Entry<Frame, Set<String>> entry : all_frames_cols.entrySet()) { Frame frame = entry.getKey(); Set<String> frame_cols = entry.getValue(); if (frame_cols.containsAll(model_column_names)) { // See if adapt throws an exception or not. try { if( model.adaptTestForTrain(new Frame(frame), false, false).length == 0 ) compatible_frames.add(frame); } catch( IllegalArgumentException e ) { // skip } } } return compatible_frames.toArray(new Frame[0]); } } /** Return all the models. */ @SuppressWarnings("unused") // called through reflection by RequestServer public ModelsV3 list(int version, ModelsV3 s) { Models m = s.createAndFillImpl(); m.models = Models.fetchAll(); return (ModelsV3) s.fillFromImplWithSynopsis(m); } // TODO: almost identical to ModelsHandler; refactor public static Model getFromDKV(String param_name, String key_str) { return getFromDKV(param_name, Key.make(key_str)); } // TODO: almost identical to ModelsHandler; refactor public static Model getFromDKV(String param_name, Key key) { if (key == null) throw new H2OIllegalArgumentException(param_name, "Models.getFromDKV()", null); Value v = DKV.get(key); if (v == null) throw new H2OKeyNotFoundArgumentException(param_name, key.toString()); Iced ice = v.get(); if (! (ice instanceof Model)) throw new H2OKeyWrongTypeArgumentException(param_name, key.toString(), Model.class, ice.getClass()); return (Model)ice; } /** Return a single model. */ @SuppressWarnings("unused") // called through reflection by RequestServer public StreamingSchema fetchPreview(int version, ModelsV3 s) { s.preview = true; return fetchJavaCode(version, s); } /** Return a single model. */ @SuppressWarnings("unused") // called through reflection by RequestServer public ModelsV3 fetch(int version, ModelsV3 s) { Model model = getFromDKV("key", s.model_id.key()); s.models = new ModelSchemaV3[1]; s.models[0] = (ModelSchemaV3)SchemaServer.schema(version, model).fillFromImpl(model); if (s.find_compatible_frames) { // TODO: refactor fetchFrameCols so we don't need this Models object Models m = new Models(); m.models = new Model[1]; m.models[0] = model; m.find_compatible_frames = true; Frame[] compatible = Models.findCompatibleFrames(model, Frames.fetchAll(), m.fetchFrameCols()); s.compatible_frames = new FrameV3[compatible.length]; // TODO: FrameBaseV3 ((ModelSchemaV3)s.models[0]).compatible_frames = new String[compatible.length]; int i = 0; for (Frame f : compatible) { s.compatible_frames[i] = new FrameV3(f).fillFromImpl(f); // TODO: FrameBaseV3 ((ModelSchemaV3)s.models[0]).compatible_frames[i] = f._key.toString(); i++; } } return s; } public StreamingSchema fetchJavaCode(int version, ModelsV3 s) { final Model model = getFromDKV("key", s.model_id.key()); final String filename = JCodeGen.toJavaId(s.model_id.key().toString()) + ".java"; // Return stream writer for given model return new StreamingSchema(model.new JavaModelStreamWriter(s.preview), filename); } @SuppressWarnings("unused") // called from the RequestServer through reflection public StreamingSchema fetchMojo(int version, ModelsV3 s) { Model model = getFromDKV("key", s.model_id.key()); String filename = JCodeGen.toJavaId(s.model_id.key().toString()) + ".zip"; return new StreamingSchema(model.getMojo(), filename); } @SuppressWarnings("unused") // called from the RequestServer through reflection public JobV3 makePartialDependence(int version, PartialDependenceV3 s) { PartialDependence partialDependence; if (s.destination_key != null) partialDependence = new PartialDependence(s.destination_key.key()); else partialDependence = new PartialDependence(Key.<PartialDependence>make()); s.fillImpl(partialDependence); //fill frame_id/model_id/nbins/etc. return new JobV3(partialDependence.execImpl()); } @SuppressWarnings("unused") // called from the RequestServer through reflection public PartialDependenceV3 fetchPartialDependence(int version, KeyV3.PartialDependenceKeyV3 s) { PartialDependence partialDependence = DKV.getGet(s.key()); return new PartialDependenceV3().fillFromImpl(partialDependence); } /** Remove an unlocked model. Fails if model is in-use. */ @SuppressWarnings("unused") // called through reflection by RequestServer public ModelsV3 delete(int version, ModelsV3 s) { Model model = getFromDKV("key", s.model_id.key()); model.delete(); // lock & remove return s; } /** * Remove ALL an unlocked models. Throws IAE for all deletes that failed * (perhaps because the Models were locked & in-use). */ @SuppressWarnings("unused") // called through reflection by RequestServer public ModelsV3 deleteAll(int version, ModelsV3 models) { final Key[] keys = KeySnapshot.globalKeysOfClass(Model.class); ArrayList<String> missing = new ArrayList<>(); Futures fs = new Futures(); for (Key key : keys) { try { getFromDKV("(none)", key).delete(null, fs); } catch (IllegalArgumentException iae) { missing.add(key.toString()); } } fs.blockForPending(); if( missing.size() != 0 ) throw new H2OKeysNotFoundArgumentException("(none)", missing.toArray(new String[missing.size()])); return models; } public ModelsV3 importModel(int version, ModelImportV3 mimport) { ModelsV3 s = Schema.newInstance(ModelsV3.class); try { URI targetUri = FileUtils.getURI(mimport.dir); Persist p = H2O.getPM().getPersistForURI(targetUri); InputStream is = p.open(targetUri.toString()); final AutoBuffer ab = new AutoBuffer(is); ab.sourceName = targetUri.toString(); Model model = (Model)Keyed.readAll(ab); s.models = new ModelSchemaV3[]{(ModelSchemaV3) SchemaServer.schema(version, model).fillFromImpl(model)}; } catch (FSIOException e) { throw new H2OIllegalArgumentException("dir", "importModel", mimport.dir); } return s; } public ModelExportV3 exportModel(int version, ModelExportV3 mexport) { Model model = getFromDKV("model_id", mexport.model_id.key()); try { URI targetUri = FileUtils.getURI(mexport.dir); // Really file, not dir Persist p = H2O.getPM().getPersistForURI(targetUri); OutputStream os = p.create(targetUri.toString(),mexport.force); model.writeAll(new AutoBuffer(os,true)).close(); // Send back mexport.dir = "file".equals(targetUri.getScheme()) ? new File(targetUri).getCanonicalPath() : targetUri.toString(); } catch (IOException e) { throw new H2OIllegalArgumentException("dir", "exportModel", e); } return mexport; } public ModelExportV3 exportMojo(int version, ModelExportV3 mexport) { Model model = getFromDKV("model_id", mexport.model_id.key()); try { URI targetUri = FileUtils.getURI(mexport.dir); // Really file, not dir Persist p = H2O.getPM().getPersistForURI(targetUri); OutputStream os = p.create(targetUri.toString(),mexport.force); ModelMojoWriter mojo = model.getMojo(); mojo.writeTo(os); // Send back mexport.dir = "file".equals(targetUri.getScheme()) ? new File(targetUri).getCanonicalPath() : targetUri.toString(); } catch (IOException e) { throw new H2OIllegalArgumentException("dir", "exportModel", e); } return mexport; } public ModelExportV3 exportModelDetails(int version, ModelExportV3 mexport){ Model model = getFromDKV("model_id", mexport.model_id.key()); try { URI targetUri = FileUtils.getURI(mexport.dir); // Really file, not dir Persist p = H2O.getPM().getPersistForURI(targetUri); //Make model schema before exporting ModelSchemaV3 modelSchema = (ModelSchemaV3)SchemaServer.schema(version, model).fillFromImpl(model); //Output model details to JSON OutputStream os = p.create(targetUri.toString(),mexport.force); os.write(modelSchema.writeJSON(new AutoBuffer()).buf()); // Send back mexport.dir = "file".equals(targetUri.getScheme()) ? new File(targetUri).getCanonicalPath() : targetUri.toString(); } catch (IOException e) { throw new H2OIllegalArgumentException("dir", "exportModelDetails", e); } return mexport; } }