package water.api;
import java.io.*;
import static water.util.FSUtils.isHdfs;
import static water.util.FSUtils.isS3N;
import java.io.File;
import java.io.IOException;
import hex.glm.GLMModel;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import water.*;
import water.persist.PersistHdfs;
import water.serial.Model2FileBinarySerializer;
import water.serial.Model2HDFSBinarySerializer;
import water.util.FSUtils;
import water.util.JCodeGen;
public class SaveModel extends Func {
static final int API_WEAVER = 1;
static public DocGen.FieldDoc[] DOC_FIELDS;
@API(help = "Model to save.", required=true, filter=Default.class)
Model model;
@API(help = "Path of directory to save model(s)", required = true, filter = Default.class, json=true, gridable = false)
String path;
@API(help="Overwrite existing files.", required = false, filter = Default.class, gridable = false)
boolean force = false;
@API(help="Save cross-validation models.", required = false, filter = Default.class, gridable = false)
boolean save_cv = true;
@Override
protected void execImpl() {
if (isHdfs(path) || isS3N(path)) saveToHdfs();
else saveToLocalFS();
}
private void saveToLocalFS() {
File parentDir = new File(path);
if (!force && parentDir.exists()) throw new IllegalArgumentException("The file " + path + " already exists!");
try {
// If force is specified then delete the file f
if (force && parentDir.exists()) delete(parentDir);
// Create folder
parentDir.mkdirs();
// Save parent model
new Model2FileBinarySerializer().save(model, new File(parentDir, JCodeGen.toJavaId(model._key.toString())));
// Write to model_names
File model_names = new File(parentDir, "model_names");
FileOutputStream is = new FileOutputStream(model_names);
OutputStreamWriter osw = new OutputStreamWriter(is);
BufferedWriter br = new BufferedWriter(osw);
br.write("main model : " + model._key.toString());
br.newLine();
// Save cross validation models
if (save_cv) {
Model[] models = getCrossValModels(model);
System.out.println(models);
for (Model m : models) {
new Model2FileBinarySerializer().save(m, new File(parentDir, JCodeGen.toJavaId(m._key.toString())));
br.write(JCodeGen.toJavaId(m._key.toString()));
br.newLine();
}
}
br.close();
} catch( IOException e ) {
throw new IllegalArgumentException("Cannot save file " + path, e);
}
}
private void saveToHdfs() {
if (FSUtils.isBareS3NBucketWithoutTrailingSlash(path)) { path += "/"; }
Path parentDir = new Path(path);
try {
FileSystem fs = FileSystem.get(parentDir.toUri(), PersistHdfs.CONF);
if (force && fs.exists(parentDir)) fs.delete(parentDir);
fs.mkdirs(parentDir);
// Save parent model
new Model2HDFSBinarySerializer(fs, force).save(model, new Path(parentDir, JCodeGen.toJavaId(model._key.toString())));
// Save parent model key to model_names file
Path model_names = new Path(parentDir, "model_names");
BufferedWriter br = new BufferedWriter(new OutputStreamWriter(fs.create(model_names,true)));
br.write("main model : " + model._key.toString());
br.newLine();
if (save_cv) {
Model[] models = getCrossValModels(model);
for (Model m : models ) {
new Model2HDFSBinarySerializer(fs, force).save(m, new Path(parentDir, JCodeGen.toJavaId(m._key.toString())));
br.write(JCodeGen.toJavaId(m._key.toString()));
br.newLine();
}
}
br.close();
} catch( IOException e ) {
throw new IllegalArgumentException("Cannot save file " + path, e);
}
}
@Override public boolean toHTML(StringBuilder sb) {
sb.append("<div class=\"alert alert-success\">")
.append("Model ")
.append(Inspector.link(model._key.toString(), model._key.toString()))
.append(" was sucessfuly saved to <b>"+path+"</b> file.");
sb.append("</div>");
return true;
}
private void delete(File f) {
if (!f.isDirectory()) {
f.delete();
} else {
File[] contents = f.listFiles();
for(File ef : contents){
delete(ef);
}
f.delete();
}
}
private Model[] getCrossValModels(Model m) {
Model[] models = null;
if (m instanceof GLMModel && ((GLMModel) m).xvalModels() == null ) {
models = NO_MODELS;
} else if (m instanceof GLMModel && ((GLMModel) m).xvalModels().length > 0) {
Key[] keys = ((GLMModel) m).xvalModels();
models = new Model[keys.length];
int i = 0;
for (Key k : keys) {
models[i++] = UKV.get(k);
}
} else {
if (m.hasCrossValModels()) {
Job.ValidatedJob j = (Job.ValidatedJob) m.get_params();
models = new Model[j.xval_models.length];
int i = 0;
for (Key k : j.xval_models) {
System.out.println(k);
models[i++] = UKV.get(k);
}
} else {
models = NO_MODELS;
}
}
assert models != null;
return models;
}
private static final Model[] NO_MODELS = new Model[] {};
}