package water;

import hex.FrameSplitter;
import static water.util.Utils.difference;
import static water.util.Utils.isEmpty;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Arrays;
import java.util.HashMap;

import water.H2O.H2OCountedCompleter;
import water.H2O.H2OEmptyCompleter;
import water.api.*;
import water.api.Request.Validator.NOPValidator;
import water.api.RequestServer.API_VERSION;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.*;
import water.util.Utils.ExpectedExceptionForDebug;
import dontweave.gson.*;

/**
 * Base class of all jobs.
 *
 * <p>A job is identified by {@link #job_key}, produces its result under {@link #destination_key}
 * and moves through the states listed in {@link JobState}. Starting a job stores it under its
 * own key and appends that key to the global {@link #LIST}; once the job ends (normally, by
 * cancellation or by failure) the stored value is replaced by a {@link JobHandle}.</p>
 */
public abstract class Job extends Func {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  /** A system key for global list of Job keys. */
  public static final Key LIST = Key.make(Constants.BUILT_IN_KEY_JOBS, (byte) 0, Key.BUILT_IN_KEY);
  /** Shared empty int array. */
  private static final int[] EMPTY = new int[0];

  @API(help = "Job key")
  public Key job_key;
  @API(help = "Destination key", filter = Default.class, json = true, validator = DestKeyValidator.class)
  public Key destination_key; // Key holding final value after job is removed

  /** Rejects raw destination-key strings containing any of {@link Key#ILLEGAL_USER_KEY_CHARS}. */
  static class DestKeyValidator extends NOPValidator<Key> {
    @Override public void validateRaw(String value) {
      if (Utils.contains(value, Key.ILLEGAL_USER_KEY_CHARS))
        throw new IllegalArgumentException("Key '" + value + "' contains illegal character! Please avoid these characters: " + Key.ILLEGAL_USER_KEY_CHARS);
    }
  }

  // Output parameters
  @API(help = "Job description") public String description;
  @API(help = "Job start time")  public long start_time;
  @API(help = "Job end time")    public long end_time;
  @API(help = "Exception")       public String exception;
  @API(help = "Job state")       public JobState state;

  transient public H2OCountedCompleter _fjtask; // Top-level task you can block on
  // Set by ValidatedJob.genericCrossValidation; makes ColumnsJob.init() return early.
  transient protected boolean _cv;

  /** Possible job states. */
  public static enum JobState {
    CREATED,   // Job was created
    RUNNING,   // Job is running
    CANCELLED, // Job was cancelled by user
    FAILED,    // Job crashed, error message/exception is available
    DONE       // Job was successfully finished
  }

  /**
   * Creates a job in state {@link JobState#CREATED} with explicit keys.
   * @param jobKey key identifying the job
   * @param dstKey key under which the job result is stored
   */
  public Job(Key jobKey, Key dstKey){
    job_key = jobKey;
    destination_key = dstKey;
    state = JobState.CREATED;
  }

  /**
   * Creates a job in state {@link JobState#CREATED} with a default job key and the simple class
   * name as description. The destination key stays unset until {@link #init()} runs.
   */
  public Job() {
    job_key = defaultJobKey();
    description = getClass().getSimpleName();
    state = JobState.CREATED;
  }

  /** Private copy constructor used by {@link JobHandle}. */
  private Job(final Job prior) {
    this(prior.job_key, prior.destination_key);
    this.description = prior.description;
    this.start_time = prior.start_time;
    this.end_time = prior.end_time;
    this.state = prior.state;
    this.exception = prior.exception;
  }

  /** @return the key identifying this job */
  public Key self() { return job_key; }

  /** @return the destination key of this job (may be null before {@link #init()}) */
  public Key dest() { return destination_key; }

  public int gridParallelism() {
    return 1;
  }

  /** @return a fresh job key homed on this node */
  protected Key defaultJobKey() {
    // Pinned to this node (i.e., the node invoked computation), because it should be almost always updated locally
    return Key.make((byte) 0, Key.JOB, H2O.SELF);
  }

  /** @return a destination key built from the simple class name and a random suffix */
  protected Key defaultDestKey() {
    return Key.make(getClass().getSimpleName() + Key.rand());
  }

  /** Start this task based on given top-level fork-join task representing job computation.
   * @param fjtask top-level job computation task.
   * @return this job in {@link JobState#RUNNING} state
   *
   * @see JobState
   * @see H2OCountedCompleter
   */
  public /** FIXME: should be final or at least protected */ Job start(final H2OCountedCompleter fjtask) {
    assert state == JobState.CREATED : "Trying to run job which was already run?";
    assert fjtask != null : "Starting a job with null working task is not permitted! Fix you API";
    _fjtask = fjtask;
    start_time = System.currentTimeMillis();
    state = JobState.RUNNING;
    // Save the full state of the job
    UKV.put(self(), this);
    // Update job list
    new TAtomic<List>() {
      @Override public List atomic(List old) {
        if( old == null ) old = new List();
        Key[] jobs = old._jobs;
        old._jobs = Arrays.copyOf(jobs, jobs.length + 1);
        old._jobs[jobs.length] = job_key;
        return old;
      }
    }.invoke(LIST);
    return this;
  }

  /** Return progress of this job.
   *
   * <p>The value is taken from the object stored under the destination key when that object
   * implements {@link Progress}; otherwise 0 is returned.</p>
   *
   * @return the value in interval <0,1> representing job progress.
   */
  public float progress() {
    Freezable f = UKV.get(destination_key);
    if( f instanceof Progress )
      return ((Progress) f).progress();
    return 0;
  }

  /** Blocks and get result of this job.
   * <p>
   * The call blocks on working task which was passed via {@link #start(H2OCountedCompleter)} method
   * and returns the result which is fetched from UKV based on job destination key.
   * </p>
   * @return result of this job fetched from UKV by destination key.
   * @see #start(H2OCountedCompleter)
   * @see UKV
   */
  public <T> T get() {
    _fjtask.join();             // Block until top-level job is done
    T ans = (T) UKV.get(destination_key);
    remove();                   // Remove self-job
    return ans;
  }

  /** Signal cancellation of this job.
   * <p>The job will be switched to state {@link JobState#CANCELLED} which signals that
   * the job was cancelled by a user. */
  public void cancel() {
    cancel((String)null, JobState.CANCELLED);
  }

  /** Signal exceptional cancellation of this job.
   *
   * <p>A {@link JobCancelledException}, or any exception whose message contains
   * "job was cancelled", is ignored. An {@link IllegalArgumentException} (directly or as cause)
   * fails the job with an "Illegal argument" message. Any other exception fails the job with its
   * stack trace and completes the still-running top-level task exceptionally.</p>
   *
   * @param ex exception causing the termination of job.
   */
  public void cancel(Throwable ex){
    if(ex instanceof JobCancelledException || ex.getMessage() != null && ex.getMessage().contains("job was cancelled"))
      return;
    if(ex instanceof IllegalArgumentException || ex.getCause() instanceof IllegalArgumentException) {
      cancel("Illegal argument: " + ex.getMessage());
      return;
    }
    StringWriter sw = new StringWriter();
    PrintWriter pw = new PrintWriter(sw);
    ex.printStackTrace(pw);
    String stackTrace = sw.toString();
    cancel("Got exception '" + ex.getClass() + "', with msg '" + ex.getMessage() + "'\n" + stackTrace, JobState.FAILED);
    if(_fjtask != null && !_fjtask.isDone()) _fjtask.completeExceptionally(ex);
  }

  /** Signal exceptional cancellation of this job.
   * @param msg cancellation message explaining reason for cancelation; a null message yields
   *            {@link JobState#CANCELLED}, any other message yields {@link JobState#FAILED}
   */
  public void cancel(final String msg) {
    JobState js = msg == null ? JobState.CANCELLED : JobState.FAILED;
    cancel(msg, js);
  }

  /**
   * Records the message and resulting state, replaces the stored job by a handle and then
   * submits a task invoking {@link #onCancelled()}.
   */
  private void cancel(final String msg, JobState resultingState ) {
    if(resultingState == JobState.CANCELLED) {
      Log.info("Job " + self() + "("  + description + ") was cancelled.");
    }
    else {
      Log.err("Job " + self() + "("  + description + ") failed.");
      Log.err(msg);
    }
    exception = msg;
    state = resultingState;
    // replace finished job by a job handle
    replaceByJobHandle();
    DKV.write_barrier();
    final Job job = this;
    H2O.submitTask(new H2OCountedCompleter() {
      @Override public void compute2() {
        job.onCancelled();
      }
    });
  }

  /**
   * Callback which is called after job cancellation (by user, by exception).
   */
  protected void onCancelled() {
  }

  /** Returns true if the job was cancelled by the user or crashed.
   * @return true if the job is in state {@link JobState#CANCELLED} or {@link JobState#FAILED}
   */
  public boolean isCancelledOrCrashed() {
    return state == JobState.CANCELLED || state == JobState.FAILED;
  }

  /** Returns true if the job was terminated by unexpected exception.
   * @return true, if the job was terminated by unexpected exception.
   */
  public boolean isCrashed() { return state == JobState.FAILED; }

  /** Returns true if this job is correctly finished.
   * @return returns true if the job finished and it was not cancelled or crashed by an exception.
   */
  public boolean isDone() { return state == JobState.DONE; }

  /** Returns true if this job is running
   * @return returns true only if this job is in running state.
   */
  public boolean isRunning() { return state == JobState.RUNNING; }

  public JobState getState() { return state; }

  /** Returns a list of all jobs in a system.
   * @return list of all jobs including running, done, cancelled, crashed jobs.
   */
  public static Job[] all() {
    List list = UKV.get(LIST);
    Job[] jobs = new Job[list==null?0:list._jobs.length];
    int j=0;
    for( int i=0; i<jobs.length; i++ ) {
      Job job = UKV.get(list._jobs[i]);
      if( job != null ) jobs[j++] = job;
    }
    // Listed keys that no longer resolve to a job are skipped; trim the unused tail.
    if( j<jobs.length ) jobs = Arrays.copyOf(jobs,j);
    return jobs;
  }

  /** Check if given job is running.
   *
   * @param job_key job key
   * @return true if job is still running else returns false.
   */
  public static boolean isRunning(Key job_key) {
    Job j = UKV.get(job_key);
    assert j!=null : "Job should be always in DKV!";
    return j.isRunning();
  }

  /**
   * Returns true if job is not running.
   * The job can be cancelled, crashed, or already done.
   *
   * @param jobkey job identification key
   * @return true if job is done, cancelled, or crashed, else false
   */
  public static boolean isEnded(Key jobkey) { return !isRunning(jobkey); }

  /**
   * Marks job as finished and records job end time.
   */
  public void remove() {
    end_time = System.currentTimeMillis();
    if( state == JobState.RUNNING )
      state = JobState.DONE;
    // Overwrite handle - copy end_time, state, msg
    replaceByJobHandle();
  }

  /** Finds a job with given key or returns null.
   *
   * @param jobkey job key
   * @return returns a job with given job key or null if a job is not found.
   */
  public static Job findJob(final Key jobkey) { return UKV.get(jobkey); }

  /** Finds a job with given dest key or returns null.
   *
   * @param destKey destination key to look for
   * @return the first job returned by {@link #all()} whose destination key equals
   *         {@code destKey}, or null if there is none
   */
  public static Job findJobByDest(final Key destKey) {
    for( Job current : Job.all() ) {
      // A listed job may not have a destination key yet; skip it instead of failing with NPE.
      Key dest = current.dest();
      if( dest != null && dest.equals(destKey) )
        return current;
    }
    return null;
  }

  /** Returns job execution time in milliseconds.
   * If the job has no end time recorded yet, the time elapsed since its start is returned. */
  public final long runTimeMs() {
    long until = end_time != 0 ? end_time : System.currentTimeMillis();
    return until - start_time;
  }

  /** Description of a speed criteria: msecs/frob */
  public String speedDescription() { return null; }

  /** Value of the described speed criteria: msecs/frob */
  public long speedValue() { return 0; }

  @Override protected Response serve() {
    fork();
    return redirect();
  }

  protected Response redirect() {
    return Progress2.redirect(this, job_key, destination_key);
  }

  /**
   * Forks computation of this job.
   *
   * <p>The call does not block.</p>
   * @return always returns this job.
   */
  public Job fork() {
    init();
    H2OCountedCompleter task = new H2OCountedCompleter() {
      @Override public void compute2() {
        try {
          try {
            // Exec always waits till the end of computation
            Job.this.exec();
            Job.this.remove();
          } catch (Throwable t) {
            if(!(t instanceof ExpectedExceptionForDebug))
              Log.err(t);
            Job.this.cancel(t);
          }
        } finally {
          tryComplete();
        }
      }
    };
    start(task);
    H2O.submitTask(task);
    return this;
  }

  /** Runs the job synchronously: initializes it, marks it started, executes it and removes it. */
  @Override public void invoke() {
    init();
    start(new H2OEmptyCompleter());  // mark job started
    exec(); // execute the implementation
    remove();     // remove the job
  }

  /**
   * Invoked before job runs. This is the place to checks arguments are valid or throw
   * IllegalArgumentException. It will get invoked both from the Web and Java APIs.
   *
   * @throws IllegalArgumentException throws the exception if initialization fails to ensure
   * correct job runtime environment.
   */
  @Override protected void init() throws IllegalArgumentException {
    if (destination_key == null) destination_key = defaultDestKey();
  }

  /**
   * Block synchronously waiting for a job to end, success or not.
   *
   * <p>An interrupt received while sleeping does not end the wait; the thread's interrupt
   * status is restored before this method returns so the caller can still observe it.</p>
   *
   * @param jobkey Job to wait for.
   * @param pollingIntervalMillis Polling interval sleep time.
   */
  public static void waitUntilJobEnded(Key jobkey, int pollingIntervalMillis) {
    boolean interrupted = false;
    try {
      while (!Job.isEnded(jobkey)) {
        try {
          Thread.sleep (pollingIntervalMillis);
        } catch (InterruptedException ie) {
          interrupted = true; // keep waiting, re-assert the interrupt on exit
        } catch (Exception ignore) {
          // any other failure of sleep() is ignored and polling continues, as before
        }
      }
    } finally {
      if (interrupted) Thread.currentThread().interrupt();
    }
  }

  /**
   * Block synchronously waiting for a job to end, success or not.
   * Polls every three seconds.
   * @param jobkey Job to wait for.
   */
  public static void waitUntilJobEnded(Key jobkey) {
    int THREE_SECONDS_MILLIS = 3 * 1000;
    waitUntilJobEnded(jobkey, THREE_SECONDS_MILLIS);
  }

  /**
   * Immutable progress record counting processed chunks out of a total.
   * Every transition ({@link #update(int)}, {@link #done()}, {@link #cancel()},
   * {@link #error(String)}) returns a new instance or {@code this}.
   */
  public static class ChunkProgress extends Iced implements Progress {
    final long _nchunks;
    final long _count;
    private final Status _status;
    final String _error;

    public enum Status { Computing, Done, Cancelled, Error }

    public Status status() { return _status; }

    /** @return true for status {@code Done} or {@code Error} (not for {@code Cancelled}) */
    public boolean isDone() { return _status == Status.Done || _status == Status.Error; }

    public String error() { return _error; }

    /** Creates a progress in status {@code Computing} with no chunk processed yet. */
    public ChunkProgress(long chunksTotal) {
      _nchunks = chunksTotal;
      _count = 0;
      _status = Status.Computing;
      _error = null;
    }

    private ChunkProgress(long nchunks, long computed, Status s, String err) {
      _nchunks = nchunks;
      _count = computed;
      _status = s;
      _error = err;
    }

    /**
     * @param count number of additionally processed chunks
     * @return this instance if the progress is cancelled or in error, otherwise a new
     *         {@code Computing} progress with the count increased
     */
    public ChunkProgress update(int count) {
      if( _status == Status.Cancelled || _status == Status.Error )
        return this;
      long c = _count + count;
      return new ChunkProgress(_nchunks, c, Status.Computing, null);
    }

    public ChunkProgress done() {
      return new ChunkProgress(_nchunks, _nchunks, Status.Done, null);
    }

    public ChunkProgress cancel() {
      return new ChunkProgress(0, 0, Status.Cancelled, null);
    }

    public ChunkProgress error(String msg) {
      return new ChunkProgress(0, 0, Status.Error, msg);
    }

    /**
     * @return 1 when done; 0 when there are no chunks to process (which is the case after
     *         {@link #cancel()} or {@link #error(String)}); otherwise the processed fraction
     *         capped at 0.99
     */
    @Override public float progress() {
      if( _status == Status.Done ) return 1.0f;
      // Without this guard 0/0 yields NaN, and Math.min(0.99f, NaN) is NaN as well.
      if( _nchunks <= 0 ) return 0f;
      return Math.min(0.99f, (float) ((double) _count / (double) _nchunks));
    }
  }

  /** A job which tracks its progress in a {@link ChunkProgress} stored under a separate key. */
  public static class ChunkProgressJob extends Job {
    Key _progress;

    /**
     * @param chunksTotal total number of chunks to be processed
     * @param destinationKey destination key; the progress key is homed on the same node
     */
    public ChunkProgressJob(long chunksTotal, Key destinationKey) {
      destination_key = destinationKey;
      _progress = Key.make(Key.make()._kb, (byte) 0, Key.DFJ_INTERNAL_USER, destinationKey.home_node());
      UKV.put(_progress, new ChunkProgress(chunksTotal));
    }

    /** Asynchronously adds {@code c} processed chunks to the progress while the job is running. */
    public void updateProgress(final int c) { // c == number of processed chunks
      if( isRunning(self()) ) {
        new TAtomic<ChunkProgress>() {
          @Override public ChunkProgress atomic(ChunkProgress old) {
            if( old == null ) return null;
            return old.update(c);
          }
        }.fork(_progress);
      }
    }

    @Override public void remove() {
      super.remove();
      UKV.remove(_progress);
    }

    public final Key progressKey() { return _progress; }

    /**
     * Removes the destination value, switches the stored progress (if any) to the error status
     * carrying the exception message, and cancels the job with the exception.
     */
    public void onException(Throwable ex) {
      UKV.remove(dest());
      Value v = DKV.get(progressKey());
      if( v != null ) {
        ChunkProgress p = v.get();
        p = p.error(ex.getMessage());
        DKV.put(progressKey(), p);
      }
      cancel(ex);
    }
  }

  /**
   * @param source frame whose columns are indexed
   * @param idx column indexes to check
   * @return true if every index lies in {@code [0, source.vecs().length - 1]}
   */
  public static boolean checkIdx(Frame source, int[] idx) {
    for (int i : idx) if (i<0 || i>source.vecs().length-1) return false;
    return true;
  }

  /* Update end_time, state, msg, preserve start_time */
  private void replaceByJobHandle() {
    assert state != JobState.RUNNING : "Running job cannot be replaced.";
    final Job self = this;
    new TAtomic<Job>() {
      @Override public Job atomic(Job old) {
        if( old == null ) return null;
        JobHandle jh = new JobHandle(self);
        jh.start_time = old.start_time;
        return jh;
      }
    }.fork(job_key);
  }

  /**
   * A job which operates with a frame.
   *
   * INPUT frame
   */
  public static abstract class FrameJob extends Job {
    static final int API_WEAVER = 1;
    static public DocGen.FieldDoc[] DOC_FIELDS;

    @API(help = "Source frame", required = true, filter = Default.class, json = true)
    public Frame source;

    /**
     * Annotate the number of columns and rows of the training data set in the job parameter JSON
     * @return JsonObject annotated with num_cols and num_rows of the training data set
     */
    @Override public JsonObject toJSON() {
      JsonObject jo = super.toJSON();
      if (source != null) {
        jo.getAsJsonObject("source").addProperty("num_cols", source.numCols());
        jo.getAsJsonObject("source").addProperty("num_rows", source.numRows());
      }
      return jo;
    }
  }

  /**
   * A job which has an input represented by a frame and frame column filter.
   * The filter can be specified by ignored columns or by used columns.
   *
   * INPUT list ignored columns by idx XOR list of ignored columns by name XOR list of used columns
   *
   * @see FrameJob
   */
  public static abstract class ColumnsJob extends FrameJob {
    static final int API_WEAVER = 1;
    static public DocGen.FieldDoc[] DOC_FIELDS;

    @API(help = "Input columns (Indexes start at 0)", filter=colsFilter.class, hide=true)
    public int[] cols;
    class colsFilter extends MultiVecSelect { public colsFilter() { super("source"); } }

    @API(help = "Ignored columns by name and zero-based index", filter=colsNamesIdxFilter.class, displayName="Ignored columns")
    public int[] ignored_cols;
    class colsNamesIdxFilter extends MultiVecSelect { public colsNamesIdxFilter() {super("source", MultiVecSelectType.NAMES_THEN_INDEXES); } }

    @API(help = "Ignored columns by name", filter=colsNamesFilter.class, displayName="Ignored columns by name", hide=true)
    public int[] ignored_cols_by_name = EMPTY;
    class colsNamesFilter extends MultiVecSelect { public colsNamesFilter() {super("source", MultiVecSelectType.NAMES_ONLY); } }

    /**
     * Annotate the used and ignored columns in the job parameter JSON
     * For both the used and the ignored columns, the following rules apply:
     * If the number of columns is less or equal than 100, a dense list of used columns is reported.
     * If the number of columns is greater than 100, the number of columns is reported.
     * If the number of columns is 0, a JSON null is reported.
     * @return JsonObject annotated with used/ignored columns
     */
    @Override public JsonObject toJSON() {
      JsonObject jo = super.toJSON();
      if (!jo.has("source") || source==null) return jo;
      HashMap<String, int[]> map = new HashMap<String, int[]>();
      map.put("used_cols", cols);
      map.put("ignored_cols", ignored_cols);
      for (String key : map.keySet()) {
        int[] val = map.get(key);
        if (val != null) {
          if(val.length>100) jo.getAsJsonObject("source").addProperty("num_" + key, val.length);
          else if(val.length>0) {
            StringBuilder sb = new StringBuilder();
            for (int c : val) sb.append(c + ",");
            // drop the trailing comma
            jo.getAsJsonObject("source").addProperty(key, sb.toString().substring(0, sb.length()-1));
          } else {
            jo.getAsJsonObject("source").add(key, JsonNull.INSTANCE);
          }
        }
      }
      return jo;
    }

    /**
     * Resolves the column selection.
     *
     * <p>At most one of {@code cols}, {@code ignored_cols} and {@code ignored_cols_by_name} may
     * be given. UUID columns of the source are always added to the ignored columns. On return
     * {@code cols} holds the selected columns minus the ignored ones. Nothing beyond
     * {@code super.init()} is done when {@code _cv} is set.</p>
     *
     * @throws IllegalArgumentException if more than one selector is given, an index is out of
     *         range, or no column remains selected
     */
    @Override protected void init() {
      super.init();
      if (_cv) return;

      // At most one of the following may be specified.
      int specified = 0;
      if (!isEmpty(cols)) { specified++; }
      if (!isEmpty(ignored_cols)) { specified++; }
      if (!isEmpty(ignored_cols_by_name)) { specified++; }
      if (specified > 1) throw new IllegalArgumentException("Arguments 'cols', 'ignored_cols_by_name', and 'ignored_cols' are exclusive");

      // Unify all ignored cols specifiers to ignored_cols.
      {
        if (!isEmpty(ignored_cols_by_name)) {
          assert (isEmpty(ignored_cols));
          ignored_cols = ignored_cols_by_name;
          ignored_cols_by_name = EMPTY;
        }
        if (ignored_cols == null) {
          ignored_cols = new int[0];
        }
      }

      // At this point, ignored_cols_by_name is dead.
      assert (isEmpty(ignored_cols_by_name));

      // Gather the ignored columns in a sorted set. Duplicates get folded into one entry.
      java.util.TreeSet<Integer> ignoredSet = new java.util.TreeSet<Integer>();
      for (int c : ignored_cols) ignoredSet.add(c);

      // Add UUID cols to the ignored set.
      Vec[] vecs = source.vecs();
      for( int i = 0; i < vecs.length; i++ ) {
        if (vecs[i].isUUID()) ignoredSet.add(i);
      }

      // Rebuild ignored_cols from the set; iteration order is ascending, so the array is sorted.
      {
        ignored_cols = new int[ignoredSet.size()];
        int j = 0;
        for (Integer c : ignoredSet) ignored_cols[j++] = c;
      }

      // If the columns are not specified, then select everything.
      if (isEmpty(cols)) {
        cols = new int[source.vecs().length];
        for( int i = 0; i < cols.length; i++ )
          cols[i] = i;
      } else {
        if (!checkIdx(source, cols)) throw new IllegalArgumentException("Argument 'cols' specified invalid column!");
      }

      // Make a set difference between cols and ignored_cols.
      if (!isEmpty(ignored_cols)) {
        if (!checkIdx(source, ignored_cols)) throw new IllegalArgumentException("Argument 'ignored_cols' or 'ignored_cols_by_name' specified invalid column!");
        cols = difference(cols, ignored_cols);
        // Setup all variables in consistent way
        ignored_cols_by_name = ignored_cols;
      }

      if( cols.length == 0 ) {
        throw new IllegalArgumentException("No column selected");
      }
    }

    /** @return the vectors of {@code frame} at the indexes listed in {@code cols} */
    protected final Vec[] selectVecs(Frame frame) {
      Vec[] vecs = new Vec[cols.length];
      for( int i = 0; i < cols.length; i++ )
        vecs[i] = frame.vecs()[cols[i]];
      return vecs;
    }

    /** @return a new frame made of the names and vectors of {@code frame} listed in {@code cols} */
    protected final Frame selectFrame(Frame frame) {
      Vec[] vecs = new Vec[cols.length];
      String[] names = new String[cols.length];
      for( int i = 0; i < cols.length; i++ ) {
        vecs[i] = frame.vecs()[cols[i]];
        names[i] = frame.names()[cols[i]];
      }
      return new Frame(names, vecs);
    }
  }

  /**
   * A columns job that requires a response.
   *
   * INPUT response column from source
   */
  public static abstract class ColumnsResJob extends ColumnsJob {
    static final int API_WEAVER = 1;
    static public DocGen.FieldDoc[] DOC_FIELDS;

    @API(help="Column to use as class", required=true, filter=responseFilter.class, json = true)
    public Vec response;
    class responseFilter extends VecClassSelect { responseFilter() { super("source"); } }

    /** Swaps the positions of the "ignored_cols" and "response" arguments and makes the former ignore the response vector. */
    @Override protected void registered(API_VERSION ver) {
      super.registered(ver);
      Argument c = find("ignored_cols");
      Argument r = find("response");
      int ci = _arguments.indexOf(c);
      int ri = _arguments.indexOf(r);
      _arguments.set(ri, c);
      _arguments.set(ci, r);
      ((FrameKeyMultiVec) c).ignoreVec((FrameKeyVec)r);
    }

    /**
     * Annotate the name of the response column in the job parameter JSON
     * @return JsonObject annotated with the name of the response column
     */
    @Override public JsonObject toJSON() {
      JsonObject jo = super.toJSON();
      if (source!=null) {
        int idx = source.find(response);
        if( idx == -1 ) {
          // Fall back to the master vector of the response, if it has one.
          Vec vm = response.masterVec();
          if( vm != null ) idx = source.find(vm);
        }
        jo.getAsJsonObject("response").add("name", new JsonPrimitive(idx == -1 ? "null" : source._names[idx]));
      }
      return jo;
    }

    /**
     * In addition to the column resolution of the parent, rejects an empty source, removes the
     * response from the selected columns and rejects a constant response.
     */
    @Override protected void init() {
      super.init();
      // Check if it make sense to build a model
      if (source.numRows()==0)
        throw new H2OIllegalArgumentException(find("source"), "Cannot build a model on empty dataset!");
      // Does not alter the Response to an Enum column if Classification is
      // asked for: instead use the classification flag to decide between
      // classification or regression.
      Vec[] vecs = source.vecs();
      for( int i = cols.length - 1; i >= 0; i-- )
        if( vecs[cols[i]] == response )
          cols = Utils.remove(cols,i);

      final boolean has_constant_response = response.isEnum() ?
              response.domain().length <= 1 : response.min() == response.max();
      if (has_constant_response)
        throw new H2OIllegalArgumentException(find("response"), "Constant response column!");
    }
  }

  /**
   * A job producing a model.
   *
   * INPUT response column from source
   */
  public static abstract class ModelJob extends ModelJobWithoutClassificationField {
    static final int API_WEAVER = 1;
    static public DocGen.FieldDoc[] DOC_FIELDS;

    @API(help="Do classification or regression", filter=myClassFilter.class, json = true)
    public boolean classification = true; // we need 3-state boolean: unspecified, true/false BUT we solve that by checking UI layer to see if the classification parameter was passed
    class myClassFilter extends DoClassBoolean { myClassFilter() { super("source"); } }

    @Override protected void init() {
      super.init();
      // Reject request if classification is required and response column is float
      //Argument a4class = find("classification"); // get UI control
      //String p4class = input("classification");  // get value from HTTP requests
      // if there is UI control and classification field was passed
      final boolean classificationFieldSpecified = true; // ROLLBACK: a4class!=null ? p4class!=null : /* we are not in UI so expect that parameter is specified correctly */ true;
      if (!classificationFieldSpecified) { // can happen if a client sends a request which does not specify classification parameter
        classification = response.isEnum();
        Log.warn("Classification field is not specified - deriving according to response! The classification field set to " + classification);
      } else {
        if ( classification && response.isFloat()) throw new H2OIllegalArgumentException(find("classification"), "Requested classification on float column!");
        if (!classification && response.isEnum() ) throw new H2OIllegalArgumentException(find("classification"), "Requested regression on enum column!");
      }
    }
  }

  /**
   * A job producing a model that has no notion of Classification or Regression.
   *
   * INPUT response column from source
   */
  public static abstract class ModelJobWithoutClassificationField extends ColumnsResJob {
    // This exists to support GLM2, which determines classification/regression using the
    // family field, not a second separate field.
  }

  /**
   * Job which produces model and validate it on a given dataset.
   * INPUT validation frame
   */
  public static abstract class ValidatedJob extends ModelJob {
    static final int API_WEAVER = 1;
    static public DocGen.FieldDoc[] DOC_FIELDS;
    protected transient Vec[] _train, _valid;
    /** Validation vector extracted from validation frame. */
    protected transient Vec _validResponse;
    /** Validation response domain or null if validation is not specified or null if response is float. */
    protected transient String[] _validResponseDomain;
    /** Source response domain or null if response is float. */
    protected transient String[] _sourceResponseDomain;
    /** CM domain derived from {@link #_validResponseDomain} and {@link #_sourceResponseDomain}. */
    protected transient String[] _cmDomain;
    /** Names of columns */
    protected transient String[] _names;
    /** Name of validation response. Should be same as source response. */
    public transient String _responseName;

    /** Adapted validation frame to a computed model. */
    private transient Frame _adaptedValidation;
    private transient Vec   _adaptedValidationResponse; // Validation response adapted to computed CM domain
    private transient int[][] _fromModel2CM;            // Transformation for model response to common CM domain
    private transient int[][] _fromValid2CM;            // Transformation for validation response to common CM domain

    @API(help = "Validation frame", filter = Default.class, mustExist = true, json = true)
    public Frame validation;

    @API(help = "Number of folds for cross-validation (if no validation data is specified)", filter = Default.class, json = true)
    public int n_folds = 0;

    @API(help = "Fraction of training data (from end) to hold out for validation (if no validation data is specified)", filter = Default.class, json = true)
    public float holdout_fraction = 0;

    @API(help = "Keep cross-validation dataset splits", filter = Default.class, json = true)
    public boolean keep_cross_validation_splits = false;

    @API(help = "Cross-validation models", json = true)
    public Key[] xval_models;

    public int _cv_count = 0;

    /**
     * Helper to compute the actual progress if we're doing cross-validation.
     * This method is supposed to be called by the progress() implementation for CV-capable algos.
     * @param p Progress reported by the main job
     * @return actual progress if CV is done, otherwise returns p
     */
    public float cv_progress(float p) {
      if (n_folds >= 2) {
        return (p + _cv_count) / (n_folds + 1); //divide by 1 more to account for final scoring as extra work
      }
      return p;
    }

    /**
     * Helper to specify which arguments trigger a refresh on change
     * @param ver
     */
    @Override
    protected void registered(RequestServer.API_VERSION ver) {
      super.registered(ver);
      for (Argument arg : _arguments) {
        if ( arg._name.equals("validation")) {
          arg.setRefreshOnChange();
        }
      }
    }

    /**
     * Helper to handle arguments based on existing input values
     * @param arg
     * @param inputArgs
     */
    @Override protected void queryArgumentValueSet(Argument arg, java.util.Properties inputArgs) {
      super.queryArgumentValueSet(arg, inputArgs);
      if (arg._name.equals("n_folds") && validation != null) {
        arg.disable("Only if no validation dataset is provided.");
        n_folds = 0;
      }
    }

    /**
     * Cross-Validate this Job (to be overridden for each instance, which also calls genericCrossValidation)
     * @param splits Frames containing train/test splits
     * @param cv_preds Store the predictions for each cross-validation run
     * @param offsets Array to store the offsets of starting row indices for each cross-validation run
     * @param i Which fold of cross-validation to perform
     */
    public void crossValidate(Frame[] splits, Frame[] cv_preds, long[] offsets, int i) { throw H2O.unimpl(); }

    /**
     * Helper to perform the generic part of cross validation
     * Expected to be called from each specific instance's crossValidate method
     * @param splits Frames containing train/test splits
     * @param offsets Array to store the offsets of starting row indices for each cross-validation run
     * @param i Which fold of cross-validation to perform
     */
    final protected void genericCrossValidation(Frame[] splits, long[] offsets, int i) {
      int respidx = source.find(_responseName);
      assert(respidx != -1) : "response is not found in source!";
      job_key = Key.make(job_key.toString() + "_xval" + i); //make a new Job for CV
      assert(xval_models != null);
      destination_key = xval_models[i];
      source = splits[0];
      validation = splits[1];
      response = source.vecs()[respidx];
      n_folds = 0;
      state = Job.JobState.CREATED; //Hack to allow this job to run
      DKV.put(self(), this); //Needed to pass the Job.isRunning(cvdl.self()) check in FrameTask
      offsets[i + 1] = offsets[i] + validation.numRows();
      _cv = true; //Hack to allow init() to pass for ColumnsJob (allow cols/ignored_cols to co-exist)
      invoke();
    }

    /**
     * Annotate the number of columns and rows of the validation data set in the job parameter JSON
     * @return JsonObject annotated with num_cols and num_rows of the validation data set
     */
    @Override public JsonObject toJSON() {
      JsonObject jo = super.toJSON();
      if (validation != null) {
        jo.getAsJsonObject("validation").addProperty("num_cols", validation.numCols());
        jo.getAsJsonObject("validation").addProperty("num_rows", validation.numRows());
      }
      return jo;
    }

    /**
     * Validates the validation/cross-validation/holdout settings, creates the cross-validation
     * model keys, optionally splits off a holdout frame, and derives the training vectors,
     * column names and (for classification) the response and confusion-matrix domains.
     *
     * @throws UnsupportedOperationException if both a validation frame and folds are given, or
     *         the number of folds is negative
     * @throws IllegalArgumentException for an invalid holdout configuration or a validation
     *         frame lacking the response column
     */
    @Override protected void init() {
      if ( validation != null && n_folds != 0 ) throw new UnsupportedOperationException("Cannot specify a validation dataset and non-zero number of cross-validation folds.");
      if ( n_folds < 0 ) throw new UnsupportedOperationException("The number of cross-validation folds must be >= 0.");
      super.init();
      xval_models = new Key[n_folds];
      for (int i=0; i<xval_models.length; ++i)
        xval_models[i] = Key.make(dest().toString() + "_xval" + i);

      // Index of the response within the source; stays 0 if the response vector is not found.
      int rIndex = 0;
      for( int i = 0; i < source.vecs().length; i++ )
        if( source.vecs()[i] == response ) {
          rIndex = i;
          break;
        }
      _responseName = source._names != null && rIndex >= 0 ? source._names[rIndex] : "response";

      if (holdout_fraction > 0) {
        if (holdout_fraction >= 1)
          throw new IllegalArgumentException("Holdout fraction must be less than 1.");
        if (validation != null)
          throw new IllegalArgumentException("Cannot specify both a holdout fraction and a validation frame.");
        if (n_folds != 0)
          throw new IllegalArgumentException("Cannot specify both a holdout fraction and a n-fold cross-validation.");

        Log.info("Holding out last " + Utils.formatPct(holdout_fraction) + " of training data.");
        FrameSplitter fs = new FrameSplitter(source, new float[]{1 - holdout_fraction});
        H2O.submitTask(fs).join();
        Frame[] splits = fs.getResult();
        source = splits[0];
        response = source.vecs()[rIndex];
        validation = splits[1];
        Log.warn("Allocating data split frames: " + source._key.toString() + " and " + validation._key.toString());
        Log.warn("Both will be kept after the the model is trained. It's the user's responsibility to manage their lifetime.");
      }

      _train = selectVecs(source);
      _names = new String[cols.length];
      for( int i = 0; i < cols.length; i++ )
        _names[i] = source._names[cols[i]];

      // Compute source response domain
      if (classification) _sourceResponseDomain = getVectorDomain(response);
      // Is validation specified?
      if( validation != null ) {
        // Extract a validation response
        int idx = validation.find(source.names()[rIndex]);
        if( idx == -1 ) throw new IllegalArgumentException("Validation set does not have a response column called "+_responseName);
        _validResponse = validation.vecs()[idx];
        // Compute output confusion matrix domain for classification:
        // - if validation dataset is specified then CM domain is union of train and validation response domains
        //   else it is only domain of response column.
        if (classification) {
          _validResponseDomain = getVectorDomain(_validResponse);
          if (_validResponseDomain!=null) {
            _cmDomain = Utils.domainUnion(_sourceResponseDomain, _validResponseDomain);
            if (!Arrays.deepEquals(_sourceResponseDomain, _validResponseDomain)) {
              _fromModel2CM = Model.getDomainMapping(_cmDomain, _sourceResponseDomain, false); // transformation from model produced response ~> cmDomain
              _fromValid2CM = Model.getDomainMapping(_cmDomain, _validResponseDomain , false); // transformation from validation response domain ~> cmDomain
            }
          } else _cmDomain = _sourceResponseDomain;
        } /* end of if classification */
      } else if (classification) _cmDomain = _sourceResponseDomain;
    }

    /**
     * @param v integer or enum vector, or null
     * @return null for a null vector; the domain of an enum vector; for any other vector the
     *         domain of a temporary enum conversion, which is removed again
     */
    protected String[] getVectorDomain(final Vec v) {
      assert v==null || v.isInt() || v.isEnum() : "Cannot get vector domain!";
      if (v==null) return null;
      String[] r;
      if (v.isEnum()) {
        r = v.domain();
      } else {
        Vec tmp = v.toEnum();
        r = tmp.domain();
        UKV.remove(tmp._key);
      }
      return r;
    }

    /** Returns true if the job has specified validation dataset. */
    protected final boolean  hasValidation() { return validation!=null; }

    /** Returns a domain for confusion matrix. */
    protected final String[] getCMDomain() { return _cmDomain; }

    /** Return validation dataset which can be adapted to a model if it is necessary. */
    protected final Frame    getValidation() { return _adaptedValidation!=null ? _adaptedValidation : validation; }

    /** Returns original validation dataset. */
    protected final Frame    getOrigValidation() { return validation; }

    public final Response2CMAdaptor getValidAdaptor() { return new Response2CMAdaptor(); }

    /**
     * Adapts the validation frame to the given model and, when a validation-to-CM mapping
     * exists, caches the validation response transformed to the confusion matrix domain.
     * Does nothing when no validation frame is set.
     */
    protected final void prepareValidationWithModel(final Model model) {
      if (validation == null) return;
      Frame[] av = model.adapt(validation, false);
      _adaptedValidation = av[0];
      gtrash(av[1]); // delete this after computation
      if (_fromValid2CM!=null) {
        assert classification : "Validation response transformation should be declared only for classification!";
        assert _fromModel2CM != null : "Model response transformation should exist if validation response transformation exists!";
        Vec tmp = _validResponse.toEnum();
        _adaptedValidationResponse = tmp.makeTransf(_fromValid2CM, getCMDomain()); // Add an original response adapted to CM domain
        gtrash(_adaptedValidationResponse); // Add the created vector to a clean-up list
        gtrash(tmp);
      }
    }

    /** A micro helper for transforming model/validation responses to confusion matrix domain. */
    public class Response2CMAdaptor {
      /** Adapt given vector produced by a model to confusion matrix domain. Always return a new vector which needs to be deleted. */
      public Vec adaptModelResponse2CM(final Vec v) { return  v.makeTransf(_fromModel2CM, getCMDomain()); }
      /** Adapt given validation vector to confusion matrix domain. Always return a new vector which needs to be deleted. */
      public Vec adaptValidResponse2CM(final Vec v) { return  v.makeTransf(_fromValid2CM, getCMDomain()); }
      /** Returns validation dataset. */
      public Frame getValidation() { return ValidatedJob.this.getValidation(); }
      /** Return cached validation response already adapted to CM domain. */
      public Vec getAdaptedValidationResponse2CM() { return _adaptedValidationResponse; }
      /** Return cm domain. */
      public String[] getCMDomain() { return ValidatedJob.this.getCMDomain(); }
      /** Returns true if model/validation responses need to be adapted to confusion matrix domain. */
      public boolean needsAdaptation2CM() { return _fromModel2CM != null; }
      /** Return the adapted response name */
      public String adaptedValidationResponse(final String response) { return response + ".adapted"; }
    }
  }

  /** Implemented by values which can report progress as a float. */
  public interface Progress {
    float progress();
  }

  public interface ProgressMonitor {
    public void update(long n);
  }

  /** Holder of a failure message. */
  public static class Fail extends Iced {
    public final String _message;
    public Fail(String message) { _message = message; }
  }

  /** The list of job keys stored under {@link #LIST}. */
  public static final class List extends Iced {
    Key[] _jobs = new Key[0];

    /** @return a copy holding a cloned array of cloned keys */
    @Override
    public List clone(){
      List l = new List();
      l._jobs = _jobs.clone();
      for(int i = 0; i < l._jobs.length; ++i)
        l._jobs[i] = (Key)l._jobs[i].clone();
      return l;
    }
  }

  /** Almost lightweight job handle containing the same content
   * as pure Job class.
   */
  public static class JobHandle extends Job {
    public JobHandle(final Job job) { super(job); }
  }

  /** Exception signalling job cancellation; {@link #cancel(Throwable)} ignores it. */
  public static class JobCancelledException extends RuntimeException {
    public JobCancelledException(){super("job was cancelled!");}
    public JobCancelledException(String msg){super("job was cancelled! with msg '" + msg + "'");}
  }

  /** Hygienic method to prevent accidental capture of non desired values. */
  public static <T extends FrameJob> T hygiene(T job) {
    job.source = null;
    return job;
  }

  /** Hygienic method to prevent accidental capture of non desired values. */
  public static <T extends ValidatedJob> T hygiene(T job) {
    job.source = null;
    job.validation = null;
    return job;
  }
}