package hex;
import hex.KMeans2.KMeans2Model;
import hex.KMeans2.KMeans2ModelView;
import hex.NeuralNet.NeuralNetModel;
import hex.drf.DRF.DRFModel;
import hex.gbm.GBM.GBMModel;
import hex.deeplearning.DeepLearningModel;
import hex.singlenoderf.SpeeDRFModel;
import hex.singlenoderf.SpeeDRFModelView;
import water.*;
import water.api.*;
import water.util.Utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
/**
 * A grid search: a fixed set of model-building {@link Job}s run with bounded parallelism
 * under a single parent job.
 *
 * <p>When it starts executing, the grid stores itself under {@code destination_key};
 * {@link GridSearchProgress} reads it back from there to render per-job results.
 */
public class GridSearch extends Job {
  /**
   * The jobs in this grid. {@code jobs[0]} supplies the parallelism limit and the argument
   * columns shown in the HTML view, so the jobs are presumably all of the same kind.
   */
  public Job[] jobs;

  public GridSearch(){
  }

  /**
   * Publishes this grid under {@code destination_key}, then runs {@link #jobs} in array order,
   * keeping at most {@code jobs[0].gridParallelism()} of them forked-but-not-awaited at a time.
   * Jobs are awaited in the order they were forked. No further jobs are forked or awaited once
   * {@code isRunning(self())} becomes false.
   *
   * @throws RuntimeException wrapping any exception raised while waiting on a child job
   */
  @Override protected void execImpl() {
    UKV.put(destination_key, this);
    // An empty grid has nothing to run; guard before reading jobs[0] below.
    if( jobs == null || jobs.length == 0 )
      return;
    int max = jobs[0].gridParallelism();
    // Invariant: jobs[head..tail) have been forked and not yet awaited.
    int head = 0, tail = 0;
    while( head < jobs.length && isRunning(self()) ) {
      if( tail - head < max && tail < jobs.length )
        jobs[tail++].fork();
      else {
        try {
          jobs[head++].get();
        } catch( Exception e ) {
          throw new RuntimeException(e);
        }
      }
    }
  }

  /** Propagates cancellation of the grid to every child job. */
  @Override protected void onCancelled() {
    if( jobs == null )
      return;
    for( Job job : jobs )
      job.cancel();
  }

  /**
   * Returns the summed progress of the started jobs ({@code start_time > 0}) divided by the
   * total number of jobs, capped at 1. The sum is seeded with 0.1, so the grid reports a small
   * non-zero value ({@code 0.1 / jobs.length}) even before any job has started.
   */
  @Override public float progress() {
    double d = 0.1;
    for( Job job : jobs )
      if( job.start_time > 0 )
        d += job.progress();
    return Math.min(1f, (float) (d / jobs.length));
  }

  /** Redirects to the {@link GridSearchProgress} page for this grid's job and destination keys. */
  @Override public Response redirect() {
    String redirectName = new GridSearchProgress().href();
    return Response.redirect(this, redirectName, "job_key", job_key, "destination_key", destination_key);
  }

  /**
   * Progress page for a {@link GridSearch}: lists every job of the grid with its arguments,
   * run time, model link and prediction error, sorted by increasing error.
   */
  public static class GridSearchProgress extends Progress2 {
    static final int API_WEAVER = 1;
    static public DocGen.FieldDoc[] DOC_FIELDS;

    @API(help = "Jobs")
    public Job[] jobs;

    @API(help = "Prediction Errors")
    public double[] prediction_errors;

    @API(help = "State")
    public String[] job_state;

    /**
     * Loads the grid stored under {@code destination_key} (if any) and refreshes
     * {@link #prediction_errors} and {@link #job_state}.
     */
    @Override protected Response serve() {
      Response response = super.serve();
      if( destination_key != null ) {
        GridSearch grid = UKV.get(destination_key);
        if( grid != null )
          jobs = grid.jobs;
        updateErrors(null);
      }
      return response;
    }

    /**
     * Recomputes {@link #prediction_errors} and {@link #job_state} from {@link #jobs}, in job
     * order. For a job whose destination holds a {@link Model}, the error is
     * {@code cm().err()} when the model is a classifier with a confusion matrix, and
     * {@code mse()} otherwise; jobs without a model keep {@code Double.POSITIVE_INFINITY}.
     *
     * @param infos if non-null, receives one {@link JobInfo} per job, in job order
     */
    void updateErrors(ArrayList<JobInfo> infos) {
      if( jobs == null ) return;
      prediction_errors = new double[jobs.length];
      job_state = new String[jobs.length];
      for( int i = 0; i < jobs.length; i++ ) {
        Job job = jobs[i];
        JobInfo info = new JobInfo();
        info._job = job;
        if( job.dest() != null ) {
          Object value = UKV.get(job.dest());
          info._model = value instanceof Model ? (Model) value : null;
          if( info._model != null ) {
            info._cm = info._model.cm();
            info._error = info._model.mse();
            // Classifiers with a confusion matrix report its error instead of the MSE.
            if( info._cm != null && info._model.isClassifier() )
              info._error = info._cm.err();
          }
        }
        if( infos != null ) infos.add(info);
        prediction_errors[i] = info._error;
        job_state[i] = info._job.state.toString();
      }
    }

    /**
     * Renders one table row per job: the first job's arguments (minus a few filtered-out keys),
     * run time, an optional speed column, a link to the model, and its error and F1 score.
     * Rows are sorted by increasing error, so jobs without a model (error +infinity) come after
     * those with a finite error.
     */
    @Override public boolean toHTML(StringBuilder sb) {
      if( jobs != null && jobs.length > 0 ) {
        DocGen.HTML.arrayHead(sb);
        sb.append("<tr class='warning'>");
        // Copy so that filtering does not mutate the list returned by arguments().
        ArrayList<Argument> args = new ArrayList<Argument>(jobs[0].arguments());
        // Filter some keys to simplify UI
        filter(args, "destination_key", "source", "cols", "ignored_cols", "ignored_cols_by_name", //
            "response", "classification", "validation");
        for( Argument arg : args ) sb.append("<td><b>").append(arg._name).append("</b></td>");
        sb.append("<td><b>").append("run time").append("</b></td>");
        String perf = jobs[0].speedDescription();
        if( perf != null )
          sb.append("<td><b>").append(perf).append("</b></td>");
        sb.append("<td><b>").append("model key").append("</b></td>");
        sb.append("<td><b>").append("prediction error").append("</b></td>");
        sb.append("<td><b>").append("F1 score").append("</b></td>");
        sb.append("</tr>");

        ArrayList<JobInfo> infos = new ArrayList<JobInfo>();
        updateErrors(infos);
        Collections.sort(infos, new Comparator<JobInfo>() {
          @Override public int compare(JobInfo a, JobInfo b) {
            return Double.compare(a._error, b._error);
          }
        });

        for( JobInfo info : infos ) {
          sb.append("<tr>");
          for( Argument a : args ) {
            try {
              Object value = a._field.get(info._job);
              String s;
              if( value instanceof int[] )
                s = Utils.sampleToString((int[]) value, 20);
              else if( value instanceof double[] )
                s = Utils.sampleToString((double[]) value, 20);
              else
                s = "" + value;
              sb.append("<td>").append(s).append("</td>");
            } catch( Exception e ) {
              throw new RuntimeException(e);
            }
          }
          String runTime = "Pending", speed = "";
          if( info._job.start_time != 0 ) {
            runTime = PrettyPrint.msecs(info._job.runTimeMs(), true);
            speed = perf != null ? PrettyPrint.msecs(info._job.speedValue(), true) : "";
          }
          sb.append("<td>").append(runTime).append("</td>");
          if( perf != null )
            sb.append("<td>").append(speed).append("</td>");
          sb.append("<td>").append(modelLink(info)).append("</td>");
          String err, f1 = "";
          if( info._cm != null && info._cm._arr != null ) {
            err = String.format("%.2f", 100 * info._error) + "%";
            if( info._cm.isBinary() ) f1 = String.format("%.4f", info._cm.F1());
          } else err = String.format("%.5f", info._error);
          sb.append("<td><b>").append(err).append("</b></td>");
          sb.append("<td><b>").append(f1).append("</b></td>");
          sb.append("</tr>");
        }
        DocGen.HTML.arrayTail(sb);
      }
      return true;
    }

    /**
     * Builds the "model key" cell content for one job. The result is "" if the job has not
     * started or its destination is absent or not present in the DKV. Otherwise it is exactly
     * one link: a model-specific view for the recognised model types, else a generic Inspect2
     * link. Each link is built once and never wrapped by another view.
     */
    private static String modelLink(JobInfo info) {
      Key dst = info._job.dest();
      if( info._job.start_time == 0 || dst == null || DKV.get(dst) == null )
        return "";
      String name = dst.toString();
      Model model = info._model;
      if( model instanceof GBMModel )
        return GBMModelView.link(name, dst);
      else if( model instanceof DRFModel )
        return DRFModelView.link(name, dst);
      else if( model instanceof NeuralNetModel )
        return NeuralNetModelView.link(name, dst);
      else if( model instanceof DeepLearningModel )
        return DeepLearningModelView.link(name, dst);
      else if( model instanceof KMeans2Model )
        return KMeans2ModelView.link(name, dst);
      else if( model instanceof SpeeDRFModel )
        return SpeeDRFModelView.link(name, dst);
      else
        return Inspect2.link(name, dst);
    }

    /** Per-job data gathered by {@link #updateErrors}. */
    static class JobInfo {
      Job _job;
      Model _model;        // null when the job's destination holds no Model
      ConfusionMatrix _cm; // null unless the model supplied one
      double _error = Double.POSITIVE_INFINITY; // stays +infinity when no model error is known
    }

    /** Removes, in place, every argument whose {@code _name} equals one of {@code names}. */
    static void filter(ArrayList<Argument> args, String... names) {
      for( String name : names )
        for( int i = args.size() - 1; i >= 0; i-- )
          if( args.get(i)._name.equals(name) )
            args.remove(i);
    }

    /** Reports completion for this progress page; the destination key is not inspected. */
    @Override protected Response jobDone(final Key dst) {
      return Response.done(this);
    }
  }
}