package water.util; import hex.NFoldFrameExtractor; import water.*; import water.fvec.Frame; import water.fvec.Vec; public class CrossValUtils { /** * Cross-Validate a ValidatedJob * @param job (must contain valid entries for n_folds, validation, destination_key, source, response) */ public static void crossValidate(Job.ValidatedJob job) { if (job.state != Job.JobState.RUNNING) return; //don't do cross-validation if the full model builder failed if (job.validation != null) throw new IllegalArgumentException("Cannot provide validation dataset and n_folds > 0 at the same time."); if (job.n_folds <= 1) throw new IllegalArgumentException("n_folds must be >= 2 for cross-validation."); final String basename = job.destination_key.toString(); long[] offsets = new long[job.n_folds +1]; Frame[] cv_preds = new Frame[job.n_folds]; try { for (int i = 0; i < job.n_folds; ++i) { if (job.state != Job.JobState.RUNNING) break; Key[] destkeys = new Key[]{Key.make(basename + "_xval" + i + "_train"), Key.make(basename + "_xval" + i + "_holdout")}; NFoldFrameExtractor nffe = new NFoldFrameExtractor(job.source, job.n_folds, i, destkeys, Key.make() /*key used for locking only*/); H2O.submitTask(nffe); Frame[] splits = nffe.getResult(); // Cross-validate individual splits try { job.crossValidate(splits, cv_preds, offsets, i); //this removes the enum-ified response! job._cv_count++; } finally { // clean-up the results if (!job.keep_cross_validation_splits) for(Frame f : splits) f.delete(); } } if (job.state != Job.JobState.RUNNING) return; final int resp_idx = job.source.find(job._responseName); Vec response = job.source.vecs()[resp_idx]; boolean put_back = UKV.get(job.response._key) == null; // In the case of rebalance, rebalance response will be deleted if (put_back) { job.response = response; if (job.classification) job.response = job.response.toEnum(); DKV.put(job.response._key, job.response); //put enum-ified response back to K-V store } ((Model)UKV.get(job.destination_key)).scoreCrossValidation(job, job.source, response, cv_preds, offsets); if (put_back) UKV.remove(job.response._key); } finally { // clean-up prediction frames for splits for(Frame f: cv_preds) if (f!=null) f.delete(); } } }