package hex; import hex.genmodel.GenModel; import hex.genmodel.MojoModel; import hex.genmodel.easy.EasyPredictModelWrapper; import hex.genmodel.easy.RowData; import hex.genmodel.easy.exception.PredictException; import hex.genmodel.easy.prediction.*; import hex.genmodel.utils.DistributionFamily; import org.joda.time.DateTime; import water.*; import water.api.StreamWriter; import water.api.StreamingSchema; import water.api.schemas3.KeyV3; import water.codegen.CodeGenerator; import water.codegen.CodeGeneratorPipeline; import water.exceptions.JCodeSB; import water.fvec.*; import water.parser.BufferedString; import water.util.*; import java.io.*; import java.lang.reflect.Field; import java.util.*; import static water.util.FrameUtils.categoricalEncoder; import static water.util.FrameUtils.cleanUp; /** * A Model models reality (hopefully). * A model can be used to 'score' a row (make a prediction), or a collection of * rows on any compatible dataset - meaning the row has all the columns with the * same names as used to build the model and any categorical columns can * be adapted. 
*/ public abstract class Model<M extends Model<M,P,O>, P extends Model.Parameters, O extends Model.Output> extends Lockable<M> { public P _parms; // TODO: move things around so that this can be protected public O _output; // TODO: move things around so that this can be protected public String[] _warnings = new String[0]; public Distribution _dist; protected ScoringInfo[] scoringInfo; public IcedHashMap<Key, String> _toDelete = new IcedHashMap<>(); public interface DeepFeatures { Frame scoreAutoEncoder(Frame frame, Key destination_key, boolean reconstruction_error_per_feature); Frame scoreDeepFeatures(Frame frame, final int layer); Frame scoreDeepFeatures(Frame frame, final int layer, final Job j); //for Deep Learning Frame scoreDeepFeatures(Frame frame, final String layer, final Job j); //for Deep Water } public interface GLRMArchetypes { Frame scoreReconstruction(Frame frame, Key<Frame> destination_key, boolean reverse_transform); Frame scoreArchetypes(Frame frame, Key<Frame> destination_key, boolean reverse_transform); } public interface LeafNodeAssignment { Frame scoreLeafNodeAssignment(Frame frame, Key<Frame> destination_key); } public interface ExemplarMembers { Frame scoreExemplarMembers(Key<Frame> destination_key, int exemplarIdx); } public interface GetMostImportantFeatures { String[] getMostImportantFeatures(int n); } /** * Default threshold for assigning class labels to the target class (for binomial models) * @return threshold in 0...1 */ public double defaultThreshold() { if (_output.nclasses() != 2 || _output._training_metrics == null) return 0.5; if (_output._validation_metrics != null && ((ModelMetricsBinomial)_output._validation_metrics)._auc != null) return ((ModelMetricsBinomial)_output._validation_metrics)._auc.defaultThreshold(); if (((ModelMetricsBinomial)_output._training_metrics)._auc != null) return ((ModelMetricsBinomial)_output._training_metrics)._auc.defaultThreshold(); return 0.5; } public final boolean isSupervised() { return 
_output.isSupervised(); } /** * Identifies the default ordering method for models returned from Grid Search * @return default sort-by */ public GridSortBy getDefaultGridSortBy() { if (! isSupervised()) return null; else if (_output.nclasses() > 1) return GridSortBy.LOGLOSS; else return GridSortBy.RESDEV; } public static class GridSortBy { // intentionally not an enum to allow 3rd party extensions public static final GridSortBy LOGLOSS = new GridSortBy("logloss", false); public static final GridSortBy RESDEV = new GridSortBy("residual_deviance", false); public static final GridSortBy R2 = new GridSortBy("r2", true); public final String _name; public final boolean _decreasing; GridSortBy(String name, boolean decreasing) { _name = name; _decreasing = decreasing; } } public ToEigenVec getToEigenVec() { return null; } /** Model-specific parameter class. Each model sub-class contains * instance of one of these containing its builder parameters, with * model-specific parameters. E.g. KMeansModel extends Model and has a * KMeansParameters extending Model.Parameters; sample parameters include K, * whether or not to normalize, max iterations and the initial random seed. * * <p>The non-transient fields are input parameters to the model-building * process, and are considered "first class citizens" by the front-end - the * front-end will cache Parameters (in the browser, in JavaScript, on disk) * and rebuild Parameter instances from those caches. * * WARNING: Model Parameters is not immutable object and ModelBuilder can modify * them! */ public abstract static class Parameters extends Iced<Parameters> { /** Maximal number of supported levels in response. */ public static final int MAX_SUPPORTED_LEVELS = 1<<20; /** The short name, used in making Keys. e.g. 
"GBM" */ abstract public String algoName(); /** The pretty algo name for this Model (e.g., Gradient Boosting Machine, rather than GBM).*/ abstract public String fullName(); /** The Java class name for this Model (e.g., hex.tree.gbm.GBM, rather than GBM).*/ abstract public String javaName(); /** Default relative tolerance for convergence-based early stopping */ protected double defaultStoppingTolerance() { return 1e-3; } /** How much work will be done for this model? */ abstract public long progressUnits(); public Key<Frame> _train; // User-Key of the Frame the Model is trained on public Key<Frame> _valid; // User-Key of the Frame the Model is validated on, if any public int _nfolds = 0; public boolean _keep_cross_validation_predictions = false; public boolean _keep_cross_validation_fold_assignment = false; public boolean _parallelize_cross_validation = true; public boolean _auto_rebalance = true; public void setTrain(Key<Frame> train) { this._train = train; } public enum FoldAssignmentScheme { AUTO, Random, Modulo, Stratified } public enum CategoricalEncodingScheme { AUTO(false), OneHotInternal(false), OneHotExplicit(false), Enum(false), Binary(false), Eigen(false), LabelEncoder(false), SortByResponse(true); CategoricalEncodingScheme(boolean needResponse) { _needResponse = needResponse; } final boolean _needResponse; boolean needsResponse() { return _needResponse; } } public long _seed = -1; public long getOrMakeRealSeed(){ while (_seed==-1) { _seed = RandomUtils.getRNG(System.nanoTime()).nextLong(); Log.debug("Auto-generated time-based seed for pseudo-random number generator (because it was set to -1): " + _seed); } return _seed; } public FoldAssignmentScheme _fold_assignment = FoldAssignmentScheme.AUTO; public CategoricalEncodingScheme _categorical_encoding = CategoricalEncodingScheme.AUTO; public DistributionFamily _distribution = DistributionFamily.AUTO; public double _tweedie_power = 1.5; public double _quantile_alpha = 0.5; public double _huber_alpha = 0.9; 
// TODO: This field belongs in the front-end column-selection process and // NOT in the parameters - because this requires all model-builders to have // column strip/ignore code. public String[] _ignored_columns; // column names to ignore for training public boolean _ignore_const_cols; // True if dropping constant cols public String _weights_column; public String _offset_column; public String _fold_column; public boolean _is_cv_model; //internal helper // Scoring a model on a dataset is not free; sometimes it is THE limiting // factor to model building. By default, partially built models are only // scored every so many major model iterations - throttled to limit scoring // costs to less than 10% of the build time. This flag forces scoring for // every iteration, allowing e.g. more fine-grained progress reporting. public boolean _score_each_iteration; /** * Maximum allowed runtime in seconds for model training. Use 0 to disable. */ public double _max_runtime_secs = 0; /** * Early stopping based on convergence of stopping_metric. * Stop if simple moving average of the stopping_metric does not improve by stopping_tolerance for * k scoring events. * Can only trigger after at least 2k scoring events. Use 0 to disable. */ public int _stopping_rounds = 0; /** * Metric to use for convergence checking, only for _stopping_rounds > 0. */ public ScoreKeeper.StoppingMetric _stopping_metric = ScoreKeeper.StoppingMetric.AUTO; /** * Relative tolerance for metric-based stopping criterion: stop if relative improvement is not at least this much. */ public double _stopping_tolerance = defaultStoppingTolerance(); /** Supervised models have an expected response they get to train with! */ public String _response_column; // response column name /** Should all classes be over/under-sampled to balance the class * distribution? 
*/ public boolean _balance_classes = false; /** When classes are being balanced, limit the resulting dataset size to * the specified multiple of the original dataset size. Maximum relative * size of the training data after balancing class counts (can be less * than 1.0) */ public float _max_after_balance_size = 5.0f; /** * Desired over/under-sampling ratios per class (lexicographic order). * Only when balance_classes is enabled. * If not specified, they will be automatically computed to obtain class balance during training. */ public float[] _class_sampling_factors; /** For classification models, the maximum size (in terms of classes) of * the confusion matrix for it to be printed. This option is meant to * avoid printing extremely large confusion matrices. */ public int _max_confusion_matrix_size = 20; /** * A model key associated with a previously trained Deep Learning * model. This option allows users to build a new model as a * continuation of a previously generated model. */ public Key<? extends Model> _checkpoint; /** * A pretrained Autoencoder DL model with matching inputs and hidden layers * can be used to initialize the weights and biases (excluding the output layer). */ public Key<? extends Model> _pretrained_autoencoder; // Public no-arg constructor for reflective creation public Parameters() { _ignore_const_cols = defaultDropConsCols(); } /** @return the training frame instance */ public final Frame train() { return _train==null ? null : _train.get(); } /** @return the validation frame instance, or null * if a validation frame was not specified */ public final Frame valid() { return _valid==null ? null : _valid.get(); } /** Read-Lock both training and validation User frames. */ public void read_lock_frames(Job job) { Frame tr = train(); if (tr != null) tr.read_lock(job._key); if (_valid != null && !_train.equals(_valid)) _valid.get().read_lock(job._key); } /** Read-UnLock both training and validation User frames. 
This method is * called on crashing cleanup pathes, so handles the case where the frames * are not actually locked. */ public void read_unlock_frames(Job job) { Frame tr = train(); if( tr != null ) tr.unlock(job._key,false); if( _valid != null && !_train.equals(_valid) ) valid().unlock(job._key,false); } // Override in subclasses to change the default; e.g. true in GLM protected boolean defaultDropConsCols() { return true; } /** Type of missing columns during adaptation between train/test datasets * Overload this method for models that have sparse data handling - a zero * will preserve the sparseness. Otherwise, NaN is used. * @return real-valued number (can be NaN) */ public double missingColumnsType() { return Double.NaN; } public boolean hasCheckpoint() { return _checkpoint != null; } // FIXME: this is really horrible hack, Model.Parameters has method checksum_impl, // but not checksum, the API is totally random :( public long checksum() { return checksum_impl(); } /** * Compute a checksum based on all non-transient non-static ice-able assignable fields (incl. inherited ones) which have @API annotations. * Sort the fields first, since reflection gives us the fields in random order and we don't want the checksum to be affected by the field order. * NOTE: if a field is added to a Parameters class the checksum will differ even when all the previous parameters have the same value. If * a client wants backward compatibility they will need to compare parameter values explicitly. * * The method is motivated by standard hash implementation `hash = hash * P + value` but we use high prime numbers in random order. 
* @return checksum */ protected long checksum_impl() { long xs = 0x600DL; int count = 0; Field[] fields = Weaver.getWovenFields(this.getClass()); Arrays.sort(fields, new Comparator<Field>() { public int compare(Field field1, Field field2) { return field1.getName().compareTo(field2.getName()); } }); for (Field f : fields) { final long P = MathUtils.PRIMES[count % MathUtils.PRIMES.length]; Class<?> c = f.getType(); if (c.isArray()) { try { f.setAccessible(true); if (f.get(this) != null) { if (c.getComponentType() == Integer.TYPE){ int[] arr = (int[]) f.get(this); xs = xs * P + (long) Arrays.hashCode(arr); } else if (c.getComponentType() == Float.TYPE) { float[] arr = (float[]) f.get(this); xs = xs * P + (long) Arrays.hashCode(arr); } else if (c.getComponentType() == Double.TYPE) { double[] arr = (double[]) f.get(this); xs = xs * P + (long) Arrays.hashCode(arr); } else if (c.getComponentType() == Long.TYPE){ long[] arr = (long[]) f.get(this); xs = xs * P + (long) Arrays.hashCode(arr); } else { Object[] arr = (Object[]) f.get(this); xs = xs * P + (long) Arrays.deepHashCode(arr); } //else lead to ClassCastException } else { xs = xs * P; } } catch (IllegalAccessException e) { throw new RuntimeException(e); } catch (ClassCastException t) { throw H2O.fail(); //no support yet for int[][] etc. } } else { try { f.setAccessible(true); Object value = f.get(this); if (value != null) { xs = xs * P + (long)(value.hashCode()); } else { xs = xs * P + P; } } catch (IllegalAccessException e) { throw new RuntimeException(e); } } count++; } xs ^= (train() == null ? 43 : train().checksum()) * (valid() == null ? 
17 : valid().checksum()); return xs; } } public ModelMetrics addModelMetrics(final ModelMetrics mm) { DKV.put(mm); incrementModelMetrics(_output, mm._key); return mm; } static void incrementModelMetrics(Output out, Key k) { synchronized(out) { for (Key key : out._model_metrics) if (k.equals(key)) return; out._model_metrics = Arrays.copyOf(out._model_metrics, out._model_metrics.length + 1); out._model_metrics[out._model_metrics.length - 1] = k; } } public void addWarning(String s){ _warnings = Arrays.copyOf(_warnings,_warnings.length+1); _warnings[_warnings.length-1] = s; } /** Model-specific output class. Each model sub-class contains an instance * of one of these containing its "output": the pieces of the model needed * for scoring. E.g. KMeansModel has a KMeansOutput extending Model.Output * which contains the cluster centers. The output also includes the names, * domains and other fields which are determined at training time. */ public abstract static class Output extends Iced { /** Columns used in the model and are used to match up with scoring data * columns. The last name is the response column name (if any). */ public String _names[]; public void setNames(String[] names) { _names = names; } public String _origNames[]; /** Categorical/factor mappings, per column. Null for non-categorical cols. * Columns match the post-init cleanup columns. The last column holds the * response col categoricals for SupervisedModels. 
*/ public String _domains[][]; public String _origDomains[][]; /** List of Keys to cross-validation models (non-null iff _parms._nfolds > 1 or _parms._fold_column != null) **/ public Key _cross_validation_models[]; /** List of Keys to cross-validation predictions (if requested) **/ public Key _cross_validation_predictions[]; public Key<Frame> _cross_validation_holdout_predictions_frame_id; public Key<Frame> _cross_validation_fold_assignment_frame_id; // Model-specific start/end/run times // Each individual model's start/end/run time is reported here, not the total time to build N+1 cross-validation models, or all grid models public long _start_time; public long _end_time; public long _run_time; protected void startClock() { _start_time = System.currentTimeMillis(); } protected void stopClock() { _end_time = System.currentTimeMillis(); _run_time = _end_time - _start_time; } public Output(){this(false,false,false);} public Output(boolean hasWeights, boolean hasOffset, boolean hasFold) { _hasWeights = hasWeights; _hasOffset = hasOffset; _hasFold = hasFold; } /** Any final prep-work just before model-building starts, but after the * user has clicked "go". E.g., converting a response column to an categorical * touches the entire column (can be expensive), makes a parallel vec * (Key/Data leak management issues), and might throw IAE if there are too * many classes. */ public Output(ModelBuilder b) { _isSupervised = b.isSupervised(); if (b.error_count() > 0) throw new IllegalArgumentException(b.validationErrors()); // Capture the data "shape" the model is valid on setNames(b._train != null ? b._train.names() : new String[0]); _domains = b._train != null ? 
b._train.domains() : new String[0][]; _origNames = b._origNames; _origDomains = b._origDomains; _hasOffset = b.hasOffsetCol(); _hasWeights = b.hasWeightCol(); _hasFold = b.hasFoldCol(); _distribution = b._distribution; _priorClassDist = b._priorClassDist; assert(_job==null); // only set after job completion } /** Returns number of input features (OK for most supervised methods, need to override for unsupervised!) */ public int nfeatures() { return _names.length - (_hasOffset?1:0) - (_hasWeights?1:0) - (_hasFold?1:0) - (isSupervised()?1:0); } /** List of all the associated ModelMetrics objects, so we can delete them * when we delete this model. */ Key[] _model_metrics = new Key[0]; /** Job info: final status (canceled, crashed), build time */ public Job _job; /** * Training set metrics obtained during model training */ public ModelMetrics _training_metrics; /** * Validation set metrics obtained during model training (if a validation data set was specified) */ public ModelMetrics _validation_metrics; /** * Cross-Validation metrics obtained during model training */ public ModelMetrics _cross_validation_metrics; /** * Summary of cross-validation metrics of all k-fold models */ public TwoDimTable _cross_validation_metrics_summary; /** * User-facing model summary - Display model type, complexity, size and other useful stats */ public TwoDimTable _model_summary; /** * User-facing model scoring history - 2D table with modeling accuracy as a function of time/trees/epochs/iterations, etc. */ public TwoDimTable _scoring_history; public double[] _distribution; public double[] _modelClassDist; public double[] _priorClassDist; protected boolean _isSupervised; public boolean isSupervised() { return _isSupervised; } /** The name of the response column (which is always the last column). 
*/ protected final boolean _hasOffset; // weights and offset are kept at designated position in the names array protected final boolean _hasWeights;// only need to know if we have them protected final boolean _hasFold;// only need to know if we have them public boolean hasOffset () { return _hasOffset;} public boolean hasWeights () { return _hasWeights;} public boolean hasFold () { return _hasFold;} public String responseName() { return isSupervised()?_names[responseIdx()]:null;} public String weightsName () { return _hasWeights ?_names[weightsIdx()]:null;} public String offsetName () { return _hasOffset ?_names[offsetIdx()]:null;} public String foldName () { return _hasFold ?_names[foldIdx()]:null;} public String[] interactions() { return null; } // Vec layout is [c1,c2,...,cn,w?,o?,r], cn are predictor cols, r is response, w and o are weights and offset, both are optional public int weightsIdx() { if(!_hasWeights) return -1; return _names.length - (isSupervised()?1:0) - (hasOffset()?1:0) - 1 - (hasFold()?1:0); } public int offsetIdx() { if(!_hasOffset) return -1; return _names.length - (isSupervised()?1:0) - (hasFold()?1:0) - 1; } public int foldIdx() { if(!_hasFold) return -1; return _names.length - (isSupervised()?1:0) - 1; } public int responseIdx() { if(!isSupervised()) return -1; return _names.length-1; } /** Names of levels for a categorical response column. */ public String[] classNames() { if (_domains == null || _domains.length == 0 || !isSupervised()) return null; return _domains[_domains.length - 1]; } /** Is this model a classification model? (v. a regression or clustering model) */ public boolean isClassifier() { return isSupervised() && nclasses() > 1; } /** Is this model a binomial classification model? (v. a regression or clustering model) */ public boolean isBinomialClassifier() { return isSupervised() && nclasses() == 2; } /** Number of classes in the response column if it is categorical and the model is supervised. 
*/ public int nclasses() { String cns[] = classNames(); return cns == null ? 1 : cns.length; } // Note: some algorithms MUST redefine this method to return other model categories public ModelCategory getModelCategory() { if (isSupervised()) return (isClassifier() ? (nclasses() > 2 ? ModelCategory.Multinomial : ModelCategory.Binomial) : ModelCategory.Regression); return ModelCategory.Unknown; } public boolean isAutoencoder() { return false; } // Override in DeepLearning and so on. public synchronized void clearModelMetrics() { _model_metrics = new Key[0]; } public synchronized Key<ModelMetrics>[] getModelMetrics() { return Arrays.copyOf(_model_metrics, _model_metrics.length); } protected long checksum_impl() { return (null == _names ? 13 : Arrays.hashCode(_names)) * (null == _domains ? 17 : Arrays.deepHashCode(_domains)) * getModelCategory().ordinal(); } public void printTwoDimTables(StringBuilder sb, Object o) { for (Field f : Weaver.getWovenFields(o.getClass())) { Class<?> c = f.getType(); if (c.isAssignableFrom(TwoDimTable.class)) { try { TwoDimTable t = (TwoDimTable) f.get(this); f.setAccessible(true); if (t != null) sb.append(t.toString(1,false /*don't print the full table if too long*/)); } catch (IllegalAccessException e) { e.printStackTrace(); } } } } @Override public String toString() { StringBuilder sb = new StringBuilder(); if (_training_metrics!=null) sb.append(_training_metrics.toString()); if (_validation_metrics!=null) sb.append(_validation_metrics.toString()); if (_cross_validation_metrics!=null) sb.append(_cross_validation_metrics.toString()); printTwoDimTables(sb, this); return sb.toString(); } } // Output protected String[][] scoringDomains() { return _output._domains; } public ModelMetrics addMetrics(ModelMetrics mm) { return addModelMetrics(mm); } public abstract ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain); /** Full constructor */ public Model(Key<M> selfKey, P parms, O output) { super(selfKey); assert parms != null; _parms = 
parms; _output = output; // Output won't be set if we're assert output != null; if (_output != null) _output.startClock(); _dist = isSupervised() && _output.nclasses() == 1 ? new Distribution(_parms) : null; } /** * Deviance of given distribution function at predicted value f * @param w observation weight * @param y (actual) response * @param f (predicted) response in original response space * @return value of gradient */ public double deviance(double w, double y, double f) { return _dist.deviance(w, y, f); } public ScoringInfo[] scoring_history() { return scoringInfo; } /** * Fill a ScoringInfo with data from the ModelMetrics for this model. * @param scoringInfo */ public void fillScoringInfo(ScoringInfo scoringInfo) { scoringInfo.is_classification = this._output.isClassifier(); scoringInfo.is_autoencoder = _output.isAutoencoder(); scoringInfo.scored_train = new ScoreKeeper(this._output._training_metrics); scoringInfo.scored_valid = new ScoreKeeper(this._output._validation_metrics); scoringInfo.scored_xval = new ScoreKeeper(this._output._cross_validation_metrics); scoringInfo.validation = _output._validation_metrics != null; scoringInfo.cross_validation = _output._cross_validation_metrics != null; if (this._output.isBinomialClassifier()) { scoringInfo.training_AUC = this._output._training_metrics == null ? null: ((ModelMetricsBinomial)this._output._training_metrics)._auc; scoringInfo.validation_AUC = this._output._validation_metrics == null ? null : ((ModelMetricsBinomial)this._output._validation_metrics)._auc; } } // return the most up-to-date model metrics public ScoringInfo last_scored() { return scoringInfo == null ? 
null : scoringInfo[scoringInfo.length-1]; } // Lower is better public float loss() { switch (_parms._stopping_metric) { case MSE: return (float) mse(); case MAE: return (float) mae(); case RMSLE: return (float) rmsle(); case logloss: return (float) logloss(); case deviance: return (float) deviance(); case misclassification: return (float) classification_error(); case AUC: return (float)(1-auc()); case mean_per_class_error: return (float)mean_per_class_error(); case lift_top_group: return (float)lift_top_group(); case AUTO: default: return (float) (_output.isClassifier() ? logloss() : _output.isAutoencoder() ? mse() : deviance()); } } // loss() public int compareTo(M o) { if (o._output.isClassifier() != _output.isClassifier()) throw new UnsupportedOperationException("Cannot compare classifier against regressor."); if (o._output.isClassifier()) { if (o._output.nclasses() != _output.nclasses()) throw new UnsupportedOperationException("Cannot compare models with different number of classes."); } return (loss() < o.loss() ? -1 : loss() > o.loss() ? 1 : 0); } public double classification_error() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._classError : last_scored().validation ? last_scored().scored_valid._classError : last_scored().scored_train._classError; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; if (mm instanceof ModelMetricsBinomial) { return ((ModelMetricsBinomial)mm)._auc.defaultErr(); } else if (mm instanceof ModelMetricsMultinomial) { return ((ModelMetricsMultinomial)mm)._cm.err(); } return Double.NaN; } public double mse() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._mse : last_scored().validation ? 
last_scored().scored_valid._mse : last_scored().scored_train._mse; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; return mm.mse(); } public double mae() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._mae : last_scored().validation ? last_scored().scored_valid._mae : last_scored().scored_train._mae; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; return ((ModelMetricsRegression)mm).mae(); } public double rmsle() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._rmsle : last_scored().validation ? last_scored().scored_valid._rmsle : last_scored().scored_train._rmsle; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; return ((ModelMetricsRegression)mm).rmsle(); } public double auc() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._AUC : last_scored().validation ? last_scored().scored_valid._AUC : last_scored().scored_train._AUC; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; return ((ModelMetricsBinomial)mm)._auc._auc; } public double deviance() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._mean_residual_deviance: last_scored().validation ? 
last_scored().scored_valid._mean_residual_deviance : last_scored().scored_train._mean_residual_deviance; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; return ((ModelMetricsRegression)mm)._mean_residual_deviance; } public double logloss() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._logloss : last_scored().validation ? last_scored().scored_valid._logloss : last_scored().scored_train._logloss; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; if (mm instanceof ModelMetricsBinomial) { return ((ModelMetricsBinomial)mm).logloss(); } else if (mm instanceof ModelMetricsMultinomial) { return ((ModelMetricsMultinomial)mm).logloss(); } return Double.NaN; } public double mean_per_class_error() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._mean_per_class_error : last_scored().validation ? last_scored().scored_valid._mean_per_class_error : last_scored().scored_train._mean_per_class_error; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; if (mm instanceof ModelMetricsBinomial) { return ((ModelMetricsBinomial)mm).mean_per_class_error(); } else if (mm instanceof ModelMetricsMultinomial) { return ((ModelMetricsMultinomial)mm).mean_per_class_error(); } return Double.NaN; } public double lift_top_group() { if (scoringInfo != null) return last_scored().cross_validation ? last_scored().scored_xval._lift : last_scored().validation ? 
last_scored().scored_valid._lift : last_scored().scored_train._lift; ModelMetrics mm = _output._cross_validation_metrics != null ? _output._cross_validation_metrics : _output._validation_metrics != null ? _output._validation_metrics : _output._training_metrics; if (mm == null) return Double.NaN; if (mm instanceof ModelMetricsBinomial) { GainsLift gl = ((ModelMetricsBinomial)mm)._gainsLift; if (gl != null && gl.response_rates != null && gl.response_rates.length > 0) { return gl.response_rates[0] / gl.avg_response_rate; } } return Double.NaN; } /** Adapt a Test/Validation Frame to be compatible for a Training Frame. The * intention here is that ModelBuilders can assume the test set has the same * count of columns, and within each factor column the same set of * same-numbered levels. Extra levels are renumbered past those in the * Train set but will still be present in the Test set, thus requiring * range-checking. * * This routine is used before model building (with no Model made yet) to * check for compatible datasets, and also used to prepare a large dataset * for scoring (with a Model). * * Adaption does the following things: * - Remove any "extra" Vecs appearing only in the test and not the train * - Insert any "missing" Vecs appearing only in the train and not the test * with all NAs ({@see missingColumnsType}). This will issue a warning, * and if the "expensive" flag is false won't actually make the column * replacement column but instead will bail-out on the whole adaption (but * will continue looking for more warnings). * - If all columns are missing, issue an error. * - Renumber matching cat levels to match the Train levels; this might make * "holes" in the Test set cat levels, if some are not in the Test set. * - Extra Test levels are renumbered past the end of the Train set, hence * the train and test levels match up to all the train levels; there might * be extra Test levels past that. * - For all mis-matched levels, issue a warning. 
*
 *  The {@code test} frame is updated in-place to be compatible, by altering
 *  the names and Vecs; make a defensive copy if you do not want it modified.
 *  There is a fast-path cutout if the test set is already compatible.  Since
 *  the test-set is conditionally modified with extra CategoricalWrappedVec optionally
 *  added it is recommended to use a Scope enter/exit to track Vec lifetimes.
 *
 *  This overload forwards the model's own names, domains, parameters,
 *  interactions and {@code _toDelete} map to the static variant, with
 *  {@code catEncoded == false}.
 *
 *  @param test Testing Frame, updated in-place
 *  @param expensive Try hard to adapt; this might involve the creation of
 *  whole Vecs and thus get expensive.  If {@code false}, then only adapt if
 *  no warnings and errors; otherwise just the messages are produced.
 *  Created Vecs have to be deleted by the caller (e.g. Scope.enter/exit).
 *  @param computeMetrics Whether metrics are to be computed; when {@code true}
 *  a missing response column is an error and a missing weights column is
 *  replaced by a column of 1s (in expensive mode).
 *  @return Array of warnings; zero length (never null) for no warnings.
 *  Throws {@code IllegalArgumentException} if no columns are in common, or
 *  if any factor column has no levels in common.
 */
public String[] adaptTestForTrain(Frame test, boolean expensive, boolean computeMetrics) {
  return adaptTestForTrain(
          test,
          _output._origNames,
          _output._origDomains,
          _output._names,
          _output._domains,
          _parms,
          expensive,
          computeMetrics,
          _output.interactions(),
          getToEigenVec(),
          _toDelete,
          false);
}

/**
 * Adapt {@code test} in place so that its columns line up with the training
 * columns described by {@code names}/{@code domains} (or by
 * {@code origNames}/{@code origDomains} when the frame still looks like the
 * pre-encoding user frame).  When categorical encoding is required the frame is
 * encoded and this method calls itself once more with {@code catEncoded == true}.
 *
 * @param test Frame to be adapted
 * @param origNames Training column names before categorical column encoding - can be the same as names
 * @param origDomains Training column levels before categorical column encoding - can be the same as domains
 * @param names Training column names
 * @param domains Training column levels
 * @param parms Model parameters
 * @param expensive Whether to actually do the hard work
 * @param computeMetrics Whether metrics can be (and should be) computed
 * @param interactions Column names to create pairwise interactions with
 * @param tev Passed through to {@code categoricalEncoder} when the frame has to be encoded
 * @param toDelete Receives the keys of every Vec/Frame created here, each with a short description
 * @param catEncoded Whether the categorical columns of the test frame were already transformed via categorical_encoding
 * @return Array of warnings; zero length (never null) for no warnings
 * @throws IllegalArgumentException if a required response/offset column is missing, a column's
 *         type is incompatible with the training data, a categorical response shares no levels
 *         with the model, or no columns are in common with the training set
 */
public static String[] adaptTestForTrain(Frame test, String[] origNames, String[][] origDomains, String[] names, String[][] domains,
                                         Parameters parms, boolean expensive, boolean computeMetrics, String[] interactions, ToEigenVec tev,
                                         IcedHashMap<Key, String> toDelete, boolean catEncoded) throws IllegalArgumentException {
  String[] msg = new String[0];
  if (test == null) return msg;
  if (catEncoded && origNames==null) return msg;

  // test frame matches the training frame (after categorical encoding, if applicable)
  String[][] tdomains = test.domains();
  if (names == test._names && domains == tdomains || (Arrays.equals(names, test._names) && Arrays.deepEquals(domains, tdomains)) )
    return msg;

  // Remember the post-encoding names/domains: they are what the recursive call below must adapt to
  String[] backupNames = names;
  String[][] backupDomains = domains;

  final String weights = parms._weights_column;
  final String offset = parms._offset_column;
  final String fold = parms._fold_column;
  final String response = parms._response_column;

  // whether we need to be careful with categorical encoding - the test frame could be either in original state or in encoded state
  final boolean checkCategoricals =
          parms._categorical_encoding == Parameters.CategoricalEncodingScheme.OneHotExplicit ||
          parms._categorical_encoding == Parameters.CategoricalEncodingScheme.Eigen ||
          parms._categorical_encoding == Parameters.CategoricalEncodingScheme.Binary;

  // test frame matches the user-given frame (before categorical encoding, if applicable)
  if( checkCategoricals && origNames != null ) {
    boolean match = Arrays.equals(origNames, test._names);
    if (!match) {
      match = true;
      // In case the test set has extra columns not in the training set - check that all original pre-encoding columns are available in the test set
      // We could be lenient here and fill missing columns with NA, but then it gets difficult to decide whether this frame is pre/post encoding, if a certain fraction of columns mismatch...
      for (String s : origNames) {
        match &= ArrayUtils.contains(test.names(), s);
        if (!match) break;
      }
    }
    // still have work to do below, make sure we set the names/domains to the original user-given values such that we can do the int->enum mapping and cat. encoding below (from scratch)
    if (match) {
      names = origNames;
      domains = origDomains;
    }
  }

  // create the interactions now and bolt them on to the front of the test Frame
  if( null!=interactions ) {
    int[] interactionIndexes = new int[interactions.length];
    for(int i=0;i<interactions.length;++i)
      interactionIndexes[i] = test.find(interactions[i]);
    test.add(makeInteractions(test, false, InteractionPair.generatePairwiseInteractionsFromList(interactionIndexes), true, true, false));
  }

  // Fill value used for substituted (missing) non-fold columns
  final double missing = parms.missingColumnsType();

  // Build the validation set to be compatible with the training set.
  // Toss out extra columns, complain about missing ones, remap categoricals
  ArrayList<String> msgs = new ArrayList<>();
  Vec vvecs[] = new Vec[names.length];
  int good = 0;               // Any matching column names, at all?
  int convNaN = 0;            // count of columns that were replaced with NA
  for( int i=0; i<names.length; i++ ) {
    Vec vec = test.vec(names[i]); // Search in the given validation set
    boolean isResponse = response != null && names[i].equals(response);
    boolean isWeights = weights != null && names[i].equals(weights);
    boolean isOffset = offset != null && names[i].equals(offset);
    boolean isFold = fold != null && names[i].equals(fold);
    // If a training set column is missing in the test set, complain (if it's ok, fill in with NAs (or 0s if it's a fold-column))
    if (vec == null) {
      if (isResponse && computeMetrics)
        throw new IllegalArgumentException("Test/Validation dataset is missing response column '" + response + "'");
      else if (isOffset)
        throw new IllegalArgumentException("Test/Validation dataset is missing offset column '" + offset + "'");
      else if (isWeights && computeMetrics) {
        if (expensive) {
          vec = test.anyVec().makeCon(1);
          toDelete.put(vec._key, "adapted missing vectors");
          msgs.add(H2O.technote(1, "Test/Validation dataset is missing weights column '" + names[i] + "' (needed because a response was found and metrics are to be computed): substituting in a column of 1s"));
        }
      }
      else if (expensive) {
        String str = "Test/Validation dataset is missing column '" + names[i] + "': substituting in a column of " + (isFold ? 0 : missing);
        vec = test.anyVec().makeCon(isFold ? 0 : missing);
        toDelete.put(vec._key, "adapted missing vectors");
        if (!isFold) convNaN++;
        msgs.add(str);
      }
    }
    if( vec != null ) {          // I have a column with a matching name
      if( domains[i] != null ) { // Model expects a categorical
        if (vec.isString())
          vec = VecUtils.stringToCategorical(vec); //turn a String column into a categorical column (we don't delete the original vec here)
        if( expensive && vec.domain() != domains[i] && !Arrays.equals(vec.domain(),domains[i]) ) { // Result needs to be the same categorical
          CategoricalWrappedVec evec;
          try {
            evec = vec.adaptTo(domains[i]); // Convert to categorical or throw IAE
            toDelete.put(evec._key, "categorically adapted vec");
          } catch( NumberFormatException nfe ) {
            throw new IllegalArgumentException("Test/Validation dataset has a non-categorical column '"+names[i]+"' which is categorical in the training data");
          }
          String[] ds = evec.domain();
          assert ds != null && ds.length >= domains[i].length;
          // Adapted domain as long as train + test domains together means no test level matched a train level
          if( isResponse && vec.domain() != null && ds.length == domains[i].length+vec.domain().length )
            throw new IllegalArgumentException("Test/Validation dataset has a categorical response column '"+names[i]+"' with no levels in common with the model");
          if (ds.length > domains[i].length)
            msgs.add("Test/Validation dataset column '" + names[i] + "' has levels not trained on: " + Arrays.toString(Arrays.copyOfRange(ds, domains[i].length, ds.length)));
          vec = evec;
        }
      } else if(vec.isCategorical()) {
        // Model expects a numeric column but the test column is categorical
        if (parms._categorical_encoding == Parameters.CategoricalEncodingScheme.LabelEncoder) {
          Vec evec = vec.toNumericVec();
          toDelete.put(evec._key, "label encoded vec");
          vec = evec;
        } else {
          throw new IllegalArgumentException("Test/Validation dataset has categorical column '" + names[i] + "' which is real-valued in the training data");
        }
      }
      good++;      // Assumed compatible; not checking e.g. Strings vs UUID
    }
    vvecs[i] = vec;
  }
  // The second clause also accepts the case where only the response column is absent from the test frame
  if( good == names.length || (response != null && test.find(response) == -1 && good == names.length - 1) )  // Only update if got something for all columns
    test.restructure(names,vvecs,good);

  boolean haveCategoricalPredictors = false;
  if (expensive && checkCategoricals && !catEncoded) {
    for (int i=0; i<test.numCols(); ++i) {
      // Special columns are not predictors
      if (test.names()[i].equals(response)) continue;
      if (test.names()[i].equals(weights)) continue;
      if (test.names()[i].equals(offset)) continue;
      if (test.names()[i].equals(fold)) continue;
      // either the column of the test set is categorical (could be a numeric col that's already turned into a factor)
      if (test.vec(i).cardinality() > 0) {
        haveCategoricalPredictors = true;
        break;
      }
      // or a equally named column of the training set is categorical, but the test column isn't (e.g., numeric column provided to be converted to a factor)
      int whichCol = ArrayUtils.find(names, test.name(i));
      if (whichCol >= 0 && domains[whichCol] != null) {
        haveCategoricalPredictors = true;
        break;
      }
    }
  }
  // check if we first need to expand categoricals before calling this method again
  if (expensive && !catEncoded && haveCategoricalPredictors) {
    Frame updated = categoricalEncoder(test, new String[]{weights, offset, fold, response}, parms._categorical_encoding, tev);
    toDelete.put(updated._key, "categorically encoded frame");
    test.restructure(updated.names(), updated.vecs()); //updated in place
    // Second pass: adapt the encoded frame against the post-encoding names/domains saved above
    String[] msg2 = adaptTestForTrain(test, origNames, origDomains, backupNames, backupDomains, parms, expensive, computeMetrics, interactions, tev, toDelete, true /*catEncoded*/);
    msgs.addAll(Arrays.asList(msg2));
    return msgs.toArray(new String[msgs.size()]);
  }
  // Every column counted as "good" was an NA-filled substitute, i.e. nothing real matched
  if( good == convNaN )
    throw new IllegalArgumentException("Test/Validation dataset has no columns in common with the training set");

  return msgs.toArray(new String[msgs.size()]);
}

/**
 * Bulk score the frame, and auto-name the resulting predictions frame.
* @see #score(Frame, String) * @param fr frame which should be scored * @return A new frame containing a predicted values. For classification it * contains a column with prediction and distribution for all * response classes. For regression it contains only one column with * predicted values. * @throws IllegalArgumentException */ public Frame score(Frame fr) throws IllegalArgumentException { return score(fr, null, null, true); } /** Bulk score the frame {@code fr}, producing a Frame result; the 1st * Vec is the predicted class, the remaining Vecs are the probability * distributions. For Regression (single-class) models, the 1st and only * Vec is the prediction value. The result is in the DKV; caller is * responsible for deleting. * * @param fr frame which should be scored * @return A new frame containing a predicted values. For classification it * contains a column with prediction and distribution for all * response classes. For regression it contains only one column with * predicted values. 
* @throws IllegalArgumentException */ public Frame score(Frame fr, String destination_key) throws IllegalArgumentException { return score(fr, destination_key, null, true); } public Frame score(Frame fr, String destination_key, Job j) throws IllegalArgumentException { return score(fr, destination_key, j, true); } public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics) throws IllegalArgumentException { Frame adaptFr = new Frame(fr); computeMetrics = computeMetrics && (!isSupervised() || (adaptFr.vec(_output.responseName()) != null && !adaptFr.vec(_output.responseName()).isBad())); String[] msg = adaptTestForTrain(adaptFr,true, computeMetrics); // Adapt if (msg.length > 0) { for (String s : msg) Log.warn(s); } Frame output = predictScoreImpl(fr, adaptFr, destination_key, j, computeMetrics); // Predict & Score // Log modest confusion matrices Vec predicted = output.vecs()[0]; // Modeled/predicted response String mdomain[] = predicted.domain(); // Domain of predictions (union of test and train) // Output is in the model's domain, but needs to be mapped to the scored // dataset's domain. 
if(_output.isClassifier() && computeMetrics) { /* if (false) { assert(mdomain != null); // label must be categorical ModelMetrics mm = ModelMetrics.getFromDKV(this,fr); ConfusionMatrix cm = mm.cm(); if (cm != null && cm._domain != null) //don't print table for regression if( cm._cm.length < _parms._max_confusion_matrix_size ) { // Print size limitation Log.info(cm.table().toString(1)); } if (mm.hr() != null) { Log.info(getHitRatioTable(mm.hr())); } } */ Vec actual = fr.vec(_output.responseName()); if( actual != null ) { // Predict does not have an actual, scoring does String sdomain[] = actual.domain(); // Scored/test domain; can be null if (sdomain != null && mdomain != sdomain && !Arrays.equals(mdomain, sdomain)) output.replace(0, new CategoricalWrappedVec(actual.group().addVec(), actual._rowLayout, sdomain, predicted._key)); } } cleanup_adapt(adaptFr, fr); return output; } /** * Compute the deviances for each observation * @param valid Validation Frame (must contain the response) * @param predictions Predictions made by the model * @param outputName Name of the output frame * @return Frame containing 1 column with the per-row deviances */ public Frame computeDeviances(Frame valid, Frame predictions, String outputName) { assert (_parms._response_column!=null) : "response column can't be null"; assert valid.find(_parms._response_column)>=0 : "validation frame must contain a response column"; predictions.add(_parms._response_column, valid.vec(_parms._response_column)); if (valid.find(_parms._weights_column)>=0) predictions.add(_parms._weights_column, valid.vec(_parms._weights_column)); final int respIdx=predictions.find(_parms._response_column); final int weightIdx=predictions.find(_parms._weights_column); final Distribution myDist = _dist == null ? 
null : IcedUtils.deepCopy(_dist); if (myDist != null && myDist.distribution == DistributionFamily.huber) { myDist.setHuberDelta(hex.ModelMetricsRegression.computeHuberDelta( valid.vec(_parms._response_column), //actual predictions.vec(0), //predictions valid.vec(_parms._weights_column), //weight _parms._huber_alpha)); } return new MRTask() { @Override public void map(Chunk[] cs, NewChunk[] nc) { Chunk weight = weightIdx>=0 ? cs[weightIdx] : new C0DChunk(1, cs[0]._len); Chunk response = cs[respIdx]; for (int i=0;i<cs[0]._len;++i) { double w=weight.atd(i); double y=response.atd(i); if (_output.nclasses()==1) { //regression - deviance double f=cs[0].atd(i); if (myDist!=null && myDist.distribution == DistributionFamily.huber) { nc[0].addNum(myDist.deviance(w, y, f)); //use above custom huber delta for this dataset } else { nc[0].addNum(deviance(w, y, f)); } } else { int iact=(int)y; double err = iact < _output.nclasses() ? 1-cs[1+iact].atd(i) : 1; nc[0].addNum(w*MathUtils.logloss(err)); } } } }.doAll(Vec.T_NUM, predictions).outputFrame(Key.<Frame>make(outputName), new String[]{"deviance"}, null); } // Remove temp keys. TODO: Really should use Scope but Scope does not // currently allow nested-key-keepers. static protected void cleanup_adapt( Frame adaptFr, Frame fr ) { Key[] keys = adaptFr.keys(); for( int i=0; i<keys.length; i++ ) if( fr.find(keys[i]) == -1 ) //only delete vecs that aren't shared keys[i].remove(); DKV.remove(adaptFr._key); //delete the frame header } protected String [] makeScoringNames(){ final int nc = _output.nclasses(); final int ncols = nc==1?1:nc+1; // Regression has 1 predict col; classification also has class distribution String [] names = new String[ncols]; names[0] = "predict"; for(int i = 1; i < names.length; ++i) { names[i] = _output.classNames()[i - 1]; // turn integer class labels such as 0, 1, etc. into p0, p1, etc. 
try { Integer.valueOf(names[i]); names[i] = "p" + names[i]; } catch (Throwable t) { // do nothing, non-integer names are fine already } } return names; } /** Allow subclasses to define their own BigScore class. */ protected BigScore makeBigScoreTask(String[][] domains, String[] names , Frame adaptFrm, boolean computeMetrics, boolean makePrediction, Job j) { return new BigScore(domains[0], names != null ? names.length : 0, adaptFrm.means(), _output.hasWeights() && adaptFrm.find(_output.weightsName()) >= 0, computeMetrics, makePrediction, j); //.doAll(names.length, Vec.T_NUM, adaptFrm); } /** Score an already adapted frame. Returns a new Frame with new result * vectors, all in the DKV. Caller responsible for deleting. Input is * already adapted to the Model's domain, so the output is also. Also * computes the metrics for this frame. * * @param adaptFrm Already adapted frame * @param computeMetrics * @return A Frame containing the prediction column, and class distribution */ protected Frame predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics) { // Build up the names & domains. String[] names = makeScoringNames(); String[][] domains = new String[names.length][]; domains[0] = names.length == 1 ? null : !computeMetrics ? _output._domains[_output._domains.length-1] : adaptFrm.lastVec().domain(); // Score the dataset, building the class distribution & predictions BigScore bs = makeBigScoreTask(domains, names, adaptFrm, computeMetrics, true, j).doAll(names.length, Vec.T_NUM, adaptFrm); if (computeMetrics) bs._mb.makeModelMetrics(this, fr, adaptFrm, bs.outputFrame()); Frame predictFr = bs.outputFrame(Key.<Frame>make(destination_key), names, domains); return postProcessPredictions(predictFr); } protected Frame postProcessPredictions(Frame predictFr) { // nothing by default return predictFr; } /** Score an already adapted frame. Returns a MetricBuilder that can be used to make a model metrics. 
* @param adaptFrm Already adapted frame
 * @return MetricBuilder
 */
protected ModelMetrics.MetricBuilder scoreMetrics(Frame adaptFrm) {
  // Supervised models can only compute metrics when a usable response column is present
  final boolean computeMetrics = (!isSupervised() || (adaptFrm.vec(_output.responseName()) != null && !adaptFrm.vec(_output.responseName()).isBad()));
  // Build up the names & domains.
  //String[] names = makeScoringNames();
  String[][] domains = new String[1][];
  domains[0] = _output.nclasses() == 1 ? null : !computeMetrics ? _output._domains[_output._domains.length-1] : adaptFrm.lastVec().domain();
  // Score the dataset, building the class distribution & predictions
  // (names == null and makePrediction == false: no prediction columns are produced, only metrics)
  BigScore bs = makeBigScoreTask(domains, null, adaptFrm, computeMetrics, false, null).doAll(adaptFrm);
  return bs._mb;
}

/**
 * Distributed task that scores every row of an adapted frame with the
 * enclosing model, optionally emitting prediction columns and/or feeding a
 * metric builder.
 */
protected class BigScore extends MRTask<BigScore> {
  final protected String[] _domain; // Prediction domain; union of test and train classes
  final protected int _npredcols;  // Number of columns in prediction; nclasses+1 - can be less than the prediction domain
  public ModelMetrics.MetricBuilder _mb; // Metric builder created per map() call and merged in reduce()
  final double[] _mean;  // Column means of test frame
  final public boolean _computeMetrics;  // Whether each scored row is fed to the metric builder
  final public boolean _hasWeights;      // Whether the scored frame has the weights column
  final public boolean _makePreds;       // Whether prediction columns are written to the output chunks
  final public Job _j;                   // Job polled for stop requests and updated once per chunk; may be null

  /**
   * @param domain prediction domain
   * @param ncols number of prediction columns to emit
   * @param mean column means of the scored frame
   * @param testHasWeights whether the scored frame has the weights column
   * @param computeMetrics whether metrics are accumulated
   * @param makePreds whether prediction columns are emitted
   * @param j owning job, may be null
   * @throws IllegalArgumentException if the model was built with weights, metrics are requested,
   *         and the scored frame has no weights column
   */
  public BigScore( String[] domain, int ncols, double[] mean, boolean testHasWeights, boolean computeMetrics, boolean makePreds, Job j) {
    _j = j;
    _domain = domain; _npredcols = ncols; _mean = mean; _computeMetrics = computeMetrics; _makePreds = makePreds;
    if(_output._hasWeights && _computeMetrics && !testHasWeights)
      throw new IllegalArgumentException("Missing weights when computing validation metrics.");
    _hasWeights = testHasWeights;
  }

  /** Score all rows of one chunk set; rows with weight 0 are not scored and get all-zero predictions. */
  @Override public void map( Chunk chks[], NewChunk cpreds[] ) {
    if (isCancelled() || _j != null && _j.stop_requested()) return;
    // Weights are only consulted when metrics are computed
    Chunk weightsChunk = _hasWeights && _computeMetrics ? chks[_output.weightsIdx()] : null;
    Chunk offsetChunk = _output.hasOffset() ? chks[_output.offsetIdx()] : null;
    Chunk responseChunk = null;
    double [] tmp = new double[_output.nfeatures()];
    float [] actual = null;
    _mb = Model.this.makeMetricBuilder(_domain);
    if (_computeMetrics) {
      if (isSupervised()) {
        actual = new float[1];
        responseChunk = chks[_output.responseIdx()];
      } else
        actual = new float[chks.length];
    }
    double[] preds = _mb._work;  // Sized for the union of test and train classes
    int len = chks[0]._len;
    try {
      setupBigScorePredict();
      for (int row = 0; row < len; row++) {
        double weight = weightsChunk != null ? weightsChunk.atd(row) : 1;
        if (weight == 0) {
          if (_makePreds) {
            for (int c = 0; c < _npredcols; c++)  // Output predictions; sized for train only (excludes extra test classes)
              cpreds[c].addNum(0);
          }
          continue;
        }
        double offset = offsetChunk != null ? offsetChunk.atd(row) : 0;
        double[] p = score0(chks, weight, offset, row, tmp, preds);
        if (_computeMetrics) {
          if (isSupervised()) {
            actual[0] = (float) responseChunk.atd(row);
          } else {
            // Unsupervised: the "actual" values are the (possibly preprocessed) input columns themselves
            for (int i = 0; i < actual.length; ++i)
              actual[i] = (float) data(chks, row, i);
          }
          _mb.perRow(preds, actual, weight, offset, Model.this);
        }
        if (_makePreds) {
          for (int c = 0; c < _npredcols; c++)  // Output predictions; sized for train only (excludes extra test classes)
            cpreds[c].addNum(p[c]);
        }
      }
    } finally {
      // Always pair the setup hook with the close hook, even if scoring throws
      closeBigScorePredict();
    }
    if ( _j != null) _j.update(1);
  }

  /** Merge the metric builder of another task into this one (skipped when this task has no builder). */
  @Override public void reduce( BigScore bs ) { if(_mb != null )_mb.reduce(bs._mb); }

  /** Let the metric builder finalize itself once all reductions are done. */
  @Override protected void postGlobal() { if(_mb != null)_mb.postGlobal(); }
}

/** Hook invoked at the start of each {@code BigScore.map} call, before any row is scored; no-op by default. */
protected void setupBigScorePredict() {}

/** Hook invoked at the end of each {@code BigScore.map} call (from a {@code finally} block); no-op by default. */
protected void closeBigScorePredict() {}

// Override this if your model needs data preprocessing (on the fly standardization, NA handling)
protected double data(Chunk[] chks, int row, int col) {
  return chks[col].atd(row);
}

/** Bulk scoring API for one row.  Chunks are all compatible with the model,
 *  and expect the last Chunks are for the final distribution and prediction.
* Default method is to just load the data into the tmp array, then call
 *  subclass scoring logic.
 */
public double[] score0( Chunk chks[], int row_in_chunk, double[] tmp, double[] preds ) {
  final double unitWeight = 1;
  final double zeroOffset = 0;
  return score0(chks, unitWeight, zeroOffset, row_in_chunk, tmp, preds);
}

/**
 * Score one row of a chunk set with an explicit observation weight and offset.
 * The feature values are copied into {@code tmp}, the subclass scoring logic is
 * invoked, and for supervised classifiers the class probabilities are
 * optionally corrected and the label is written to slot 0.
 *
 * @param chks chunks compatible with the model; the first {@code tmp.length} hold the features
 * @param weight observation weight
 * @param offset observation offset
 * @param row_in_chunk row index within the chunks
 * @param tmp reusable buffer for the feature values of the row
 * @param preds reusable buffer for the predictions
 * @return the array returned by the subclass scoring logic
 */
public double[] score0( Chunk chks[], double weight, double offset, int row_in_chunk, double[] tmp, double[] preds ) {
  assert(_output.nfeatures() == tmp.length);
  for (int col = 0; col < tmp.length; col++) {
    tmp[col] = chks[col].atd(row_in_chunk);
  }
  final double[] result = score0(tmp, preds, weight, offset);
  if (!isSupervised() || !_output.isClassifier()) {
    return result;
  }
  // Correct probabilities obtained from training on oversampled data back to original distribution
  // C.f. http://gking.harvard.edu/files/0s.pdf Eq.(27)
  if (_parms._balance_classes) {
    GenModel.correctProbabilities(result, _output._priorClassDist, _output._modelClassDist);
  }
  // assign label at the very end (after potentially correcting probabilities)
  result[0] = hex.genmodel.GenModel.getPrediction(result, _output._priorClassDist, tmp, defaultThreshold());
  return result;
}

/** Subclasses implement the scoring logic.  The data is pre-loaded into a
 *  re-used temp array, in the order the model expects.  The predictions are
 *  loaded into the re-used temp array, which is also returned.
 */
protected abstract double[] score0(double[] data /*ncols*/, double[] preds /*nclasses+1*/);

/**
 * Scoring entry point for models that handle weight/offset; such models must
 * override it.  This default only accepts the trivial weight (1) and offset (0)
 * and delegates to the two-argument scoring logic.
 */
protected double[] score0(double[] data /*ncols*/, double[] preds /*nclasses+1*/, double weight, double offset) {
  assert (weight == 1 && offset == 0) : "Override this method for non-trivial weight/offset!";
  return score0(data, preds);
}

// Version where the user has just ponied-up an array of data to be scored.
// Data must be in proper order.  Handy for JUnit tests.
public double score(double[] data){ return ArrayUtils.maxIndex(score0(data, new double[_output.nclasses()])); } @Override protected Futures remove_impl( Futures fs ) { if (_output._model_metrics != null) for( Key k : _output._model_metrics ) k.remove(fs); cleanUp(_toDelete); return super.remove_impl(fs); } /** Write out K/V pairs, in this case model metrics. */ @Override protected AutoBuffer writeAll_impl(AutoBuffer ab) { if (_output._model_metrics != null) for( Key k : _output._model_metrics ) ab.putKey(k); return super.writeAll_impl(ab); } @Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) { if (_output._model_metrics != null) for( Key k : _output._model_metrics ) ab.getKey(k,fs); // Load model metrics return super.readAll_impl(ab,fs); } @Override protected long checksum_impl() { return _parms.checksum_impl() * _output.checksum_impl(); } /** * Override this in models that support serialization into the MOJO format. * @return a class that inherits from ModelMojoWriter */ public ModelMojoWriter getMojo() { throw H2O.unimpl("MOJO format is not available for " + _parms.fullName() + " models."); } // ========================================================================== /** Return a String which is a valid Java program representing a class that * implements the Model. The Java is of the form: * <pre> * class UUIDxxxxModel { * public static final String NAMES[] = { ....column names... } * public static final String DOMAINS[][] = { ....domain names... } * // Pass in data in a double[], pre-aligned to the Model's requirements. * // Jam predictions into the preds[] array; preds[0] is reserved for the * // main prediction (class for classifiers or value for regression), * // and remaining columns hold a probability distribution for classifiers. 
* double[] predict( double data[], double preds[] ); * double[] map( HashMap < String,Double > row, double data[] ); * // Does the mapping lookup for every row, no allocation * double[] predict( HashMap < String,Double > row, double data[], double preds[] ); * // Allocates a double[] for every row * double[] predict( HashMap < String,Double > row, double preds[] ); * // Allocates a double[] and a double[] for every row * double[] predict( HashMap < String,Double > row ); * } * </pre> */ public final String toJava(boolean preview, boolean verboseCode) { // 32k buffer by default ByteArrayOutputStream os = new ByteArrayOutputStream(Short.MAX_VALUE); // We do not need to close BAOS /* ignore returned stream */ toJava(os, preview, verboseCode); return os.toString(); } public final SBPrintStream toJava(OutputStream os, boolean preview, boolean verboseCode) { if (preview /* && toJavaCheckTooBig() */) { os = new LineLimitOutputStreamWrapper(os, 1000); } return toJava(new SBPrintStream(os), preview, verboseCode); } protected SBPrintStream toJava(SBPrintStream sb, boolean isGeneratingPreview, boolean verboseCode) { CodeGeneratorPipeline fileCtx = new CodeGeneratorPipeline(); // preserve file context String modelName = JCodeGen.toJavaId(_key.toString()); // HEADER sb.p("/*").nl(); sb.p(" Licensed under the Apache License, Version 2.0").nl(); sb.p(" http://www.apache.org/licenses/LICENSE-2.0.html").nl(); sb.nl(); sb.p(" AUTOGENERATED BY H2O at ").p(new DateTime().toString()).nl(); sb.p(" ").p(H2O.ABV.projectVersion()).nl(); sb.p(" ").nl(); sb.p(" Standalone prediction code with sample test data for ").p(this.getClass().getSimpleName()).p(" named ").p(modelName) .nl(); sb.nl(); sb.p(" How to download, compile and execute:").nl(); sb.p(" mkdir tmpdir").nl(); sb.p(" cd tmpdir").nl(); sb.p(" curl http:/").p(H2O.SELF.toString()).p("/3/h2o-genmodel.jar > h2o-genmodel.jar").nl(); sb.p(" curl http:/").p(H2O.SELF.toString()).p("/3/Models.java/").pobj(_key).p(" > 
").p(modelName).p(".java").nl(); sb.p(" javac -cp h2o-genmodel.jar -J-Xmx2g -J-XX:MaxPermSize=128m ").p(modelName).p(".java").nl(); // Intentionally disabled since there is no main method in generated code // sb.p("// java -cp h2o-genmodel.jar:. -Xmx2g -XX:MaxPermSize=256m -XX:ReservedCodeCacheSize=256m ").p(modelName).nl(); sb.nl(); sb.p(" (Note: Try java argument -XX:+PrintCompilation to show runtime JIT compiler behavior.)").nl(); if (_parms._offset_column != null) { sb.nl(); sb.nl(); sb.nl(); sb.p(" NOTE: Java model export does not support offset_column.").nl(); sb.nl(); Log.warn("Java model export does not support offset_column."); } if (isGeneratingPreview && toJavaCheckTooBig()) { sb.nl(); sb.nl(); sb.nl(); sb.p(" NOTE: Java model is too large to preview, please download as shown above.").nl(); sb.nl(); return sb; } sb.p("*/").nl(); sb.p("import java.util.Map;").nl(); sb.p("import hex.genmodel.GenModel;").nl(); sb.p("import hex.genmodel.annotations.ModelPojo;").nl(); sb.nl(); String algo = this.getClass().getSimpleName().toLowerCase().replace("model", ""); sb.p("@ModelPojo(name=\"").p(modelName).p("\", algorithm=\"").p(algo).p("\")").nl(); sb.p("public class ").p(modelName).p(" extends GenModel {").nl().ii(1); sb.ip("public hex.ModelCategory getModelCategory() { return hex.ModelCategory." + _output .getModelCategory() + "; }").nl(); toJavaInit(sb, fileCtx).nl(); toJavaNAMES(sb, fileCtx); toJavaNCLASSES(sb); toJavaDOMAINS(sb, fileCtx); toJavaPROB(sb); toJavaSuper(modelName, sb); // sb.p(" public String getUUID() { return Long.toString("+checksum()+"L); }").nl(); toJavaPredict(sb, fileCtx, verboseCode); sb.p("}").nl().di(1); fileCtx.generate(sb); // Append file context sb.nl(); return sb; } /** Generate implementation for super class. 
*/
protected SBPrintStream toJavaSuper(String modelName, SBPrintStream sb) {
  final String constructor = "public " + modelName + "() { super(NAMES,DOMAINS); }";
  return sb.nl().ip(constructor).nl();
}

/**
 * Emit the {@code NAMES} constant and register a file-level generator for the
 * holder class that carries the feature column names.
 */
private SBPrintStream toJavaNAMES(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
  final String javaId = JCodeGen.toJavaId(_key.toString());
  final String holderClass = "NamesHolder_"+javaId;
  sb.i().p("// ").p("Names of columns used by model.").nl();
  sb.i().p("public static final String[] NAMES = "+holderClass+".VALUES;").nl();
  // The holder class itself is generated later, together with the rest of the file context
  final CodeGenerator namesGenerator = new CodeGenerator() {
    @Override
    public void generate(JCodeSB out) {
      out.i().p("// The class representing training column names").nl();
      JCodeGen.toClassWithArray(out, null, holderClass,
                                Arrays.copyOf(_output._names, _output.nfeatures()));
    }
  };
  fileCtx.add(namesGenerator);
  return sb;
}

/** Emit the {@code NCLASSES} constant; only classifiers get one. */
protected SBPrintStream toJavaNCLASSES(SBPrintStream sb ) {
  if (!_output.isClassifier()) {
    return sb;
  }
  return JCodeGen.toStaticVar(sb, "NCLASSES",
                              _output.nclasses(),
                              "Number of output classes included in training data response column.");
}

/**
 * Emit the {@code DOMAINS} constant: one entry per scoring domain, either
 * {@code null} or a reference to a per-column holder class whose generator is
 * registered in the file context.
 */
private SBPrintStream toJavaDOMAINS(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
  final String javaId = JCodeGen.toJavaId(_key.toString());
  sb.nl();
  sb.ip("// Column domains. The last array contains domain of response column.").nl();
  sb.ip("public static final String[][] DOMAINS = new String[][] {").nl();
  final String[][] allLevels = scoringDomains();
  for (int i = 0; i < allLevels.length; i++) {
    final int column = i;
    final String[] levels = allLevels[i];
    final String holderClass = javaId+"_ColInfo_"+i;
    sb.i(1).p("/* ").p(_output._names[i]).p(" */ ");
    if (levels == null) {
      sb.p("null");
    } else {
      sb.p(holderClass).p(".VALUES");
    }
    if (i < allLevels.length - 1) {
      sb.p(',');
    }
    sb.nl();
    // Right now do not generate the class representing column
    // since it does not hold any interesting information except String array holding domain
    if (levels == null) {
      continue;
    }
    fileCtx.add(new CodeGenerator() {
      @Override
      public void generate(JCodeSB out) {
        out.ip("// The class representing column ").p(_output._names[column]).nl();
        JCodeGen.toClassWithArray(out, null, holderClass, levels);
      }
    });
  }
  return sb.ip("};").nl();
}

/** Emit the prior and model class distributions; only supervised models get them. */
protected SBPrintStream toJavaPROB(SBPrintStream sb) {
  if (!isSupervised()) {
    return sb;
  }
  JCodeGen.toStaticVar(sb, "PRIOR_CLASS_DISTRIB", _output._priorClassDist, "Prior class distribution");
  JCodeGen.toStaticVar(sb, "MODEL_CLASS_DISTRIB", _output._modelClassDist, "Class distribution used for model building");
  return sb;
}

/** Whether the generated code is too big for a preview; this default warns and answers {@code true}. */
protected boolean toJavaCheckTooBig() {
  Log.warn("toJavaCheckTooBig must be overridden for this model type to render it in the browser");
  return true;
}

// Override in subclasses to provide some top-level model-specific goodness
protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileContext) {
  return sb;
}

// Override in subclasses to provide some inside 'predict' call goodness
// Method returns code which should be appended into generated top level class after
// predict method.
protected void toJavaPredictBody(SBPrintStream body,
                                 CodeGeneratorPipeline classCtx,
                                 CodeGeneratorPipeline fileCtx,
                                 boolean verboseCode) {
  // Default: this model type has no Java code generation
  throw new IllegalArgumentException("This model type does not support conversion to Java");
}

// Wrapper around the main predict call, including the signature and return value
private SBPrintStream toJavaPredict(SBPrintStream ccsb, CodeGeneratorPipeline fileCtx, boolean verboseCode) { // ccsb = classContext
  final String[] methodHeader = {
      "// Pass in data in a double[], pre-aligned to the Model's requirements.",
      "// Jam predictions into the preds[] array; preds[0] is reserved for the",
      "// main prediction (class for classifiers or value for regression),",
      "// and remaining columns hold a probability distribution for classifiers.",
      "public final double[] score0( double[] data, double[] preds ) {"
  };
  ccsb.nl();
  for (String headerLine : methodHeader) {
    ccsb.ip(headerLine).nl();
  }
  // Collects class-level code that the predict body wants emitted after the method
  final CodeGeneratorPipeline classContext = new CodeGeneratorPipeline();
  toJavaPredictBody(ccsb.ii(1), classContext, fileCtx, verboseCode);
  ccsb.ip("return preds;").nl();
  ccsb.di(1).ip("}").nl();
  // Output class context
  classContext.generate(ccsb.ii(1));
  ccsb.di(1);
  return ccsb;
}

// Convenience method for testing: build Java, convert it to a class &
// execute it: compare the results of the new class's (JIT'd) scoring with
// the built-in (interpreted) scoring on this dataset.  Returns true if all
// is well, false is there are any mismatches.  Throws if there is any error
// (typically an AssertionError or unable to compile the POJO).
public boolean testJavaScoring(Frame data, Frame model_predictions, double rel_epsilon) { return testJavaScoring(data, model_predictions, rel_epsilon, 1e-15, 0.1); } public boolean testJavaScoring(Frame data, Frame model_predictions, double rel_epsilon, double abs_epsilon) { return testJavaScoring(data, model_predictions, rel_epsilon, abs_epsilon, 0.1); } public boolean testJavaScoring(Frame data, Frame model_predictions, double rel_epsilon, double abs_epsilon, double fraction) { ModelBuilder mb = ModelBuilder.make(_parms.algoName().toLowerCase(), null, null); boolean havePojo = mb.havePojo(); boolean haveMojo = mb.haveMojo(); Random rnd = RandomUtils.getRNG(data.byteSize()); assert data.numRows() == model_predictions.numRows(); Frame fr = new Frame(data); boolean computeMetrics = data.vec(_output.responseName()) != null && !data.vec(_output.responseName()).isBad(); try { String[] warns = adaptTestForTrain(fr,true, computeMetrics); if( warns.length > 0 ) System.err.println(Arrays.toString(warns)); // Output is in the model's domain, but needs to be mapped to the scored // dataset's domain. int[] omap = null; if( _output.isClassifier() ) { Vec actual = fr.vec(_output.responseName()); String[] sdomain = actual == null ? 
null : actual.domain(); // Scored/test domain; can be null String[] mdomain = model_predictions.vec(0).domain(); // Domain of predictions (union of test and train) if( sdomain != null && !Arrays.equals(mdomain, sdomain)) { omap = CategoricalWrappedVec.computeMap(mdomain,sdomain); // Map from model-domain to scoring-domain } } String modelName = JCodeGen.toJavaId(_key.toString()); boolean preview = false; GenModel genmodel = null; Vec[] dvecs = fr.vecs(); Vec[] pvecs = model_predictions.vecs(); double[] features = null; int num_errors = 0; int num_total = 0; // First try internal POJO via fast double[] API if (havePojo) { try { String java_text = toJava(preview, true); Class clz = JCodeGen.compile(modelName,java_text); genmodel = (GenModel)clz.newInstance(); } catch (Exception e) { e.printStackTrace(); throw H2O.fail("Internal POJO compilation failed",e); } features = MemoryManager.malloc8d(genmodel._names.length); double[] predictions = MemoryManager.malloc8d(genmodel.nclasses() + 1); // Compare predictions, counting mis-predicts for (int row=0; row<fr.numRows(); row++) { // For all rows, single-threaded if (rnd.nextDouble() >= fraction) continue; num_total++; // Native Java API for (int col = 0; col < features.length; col++) // Build feature set features[col] = dvecs[col].at(row); genmodel.score0(features, predictions); // POJO predictions for (int col = _output.isClassifier() ? 
1 : 0; col < pvecs.length; col++) { // Compare predictions double d = pvecs[col].at(row); // Load internal scoring predictions if (col == 0 && omap != null) d = omap[(int) d]; // map categorical response to scoring domain if (!MathUtils.compare(predictions[col], d, abs_epsilon, rel_epsilon)) { if (num_errors++ < 10) System.err.println("Predictions mismatch, row " + row + ", col " + model_predictions._names[col] + ", internal prediction=" + d + ", POJO prediction=" + predictions[col]); break; } } } } // EasyPredict API with POJO and/or MOJO for (int i = 0; i < 2; ++i) { if (i == 0 && !havePojo) continue; if (i == 1 && !haveMojo) continue; if (i == 1) { // MOJO final String filename = modelName + ".zip"; StreamingSchema ss = new StreamingSchema(getMojo(), filename); try { FileOutputStream os = new FileOutputStream(ss.getFilename()); ss.getStreamWriter().writeTo(os); os.close(); genmodel = MojoModel.load(filename); features = MemoryManager.malloc8d(genmodel._names.length); } catch (IOException e1) { e1.printStackTrace(); throw H2O.fail("Internal MOJO loading failed", e1); } finally { boolean deleted = new File(filename).delete(); if (!deleted) Log.warn("Failed to delete the file"); } } EasyPredictModelWrapper epmw = new EasyPredictModelWrapper( new EasyPredictModelWrapper.Config().setModel(genmodel).setConvertUnknownCategoricalLevelsToNa(true) ); RowData rowData = new RowData(); BufferedString bStr = new BufferedString(); for (int row = 0; row < fr.numRows(); row++) { // For all rows, single-threaded if (rnd.nextDouble() >= fraction) continue; if (genmodel.getModelCategory() == ModelCategory.AutoEncoder) continue; // Generate input row for (int col = 0; col < features.length; col++) { if (dvecs[col].isString()) { rowData.put(genmodel._names[col], dvecs[col].atStr(bStr, row).toString()); } else { double val = dvecs[col].at(row); rowData.put( genmodel._names[col], genmodel._domains[col] == null ? (Double) val : Double.isNaN(val) ? 
val // missing categorical values are kept as NaN, the score0 logic passes it on to bitSetContains() : (int) val < genmodel._domains[col].length ? genmodel._domains[col][(int) val] : "UnknownLevel"); //unseen levels are treated as such } } // Make a prediction AbstractPrediction p; try { p = epmw.predict(rowData); } catch (PredictException e) { num_errors++; if (num_errors < 20) { System.err.println("EasyPredict threw an exception when predicting row " + rowData); e.printStackTrace(); } continue; } // Convert model predictions and "internal" predictions into the same shape double[] expected_preds = new double[pvecs.length]; double[] actual_preds = new double[pvecs.length]; for (int col = 0; col < pvecs.length; col++) { // Compare predictions double d = pvecs[col].at(row); // Load internal scoring predictions if (col == 0 && omap != null) d = omap[(int) d]; // map categorical response to scoring domain double d2 = Double.NaN; switch (genmodel.getModelCategory()) { case Clustering: d2 = ((ClusteringModelPrediction) p).cluster; break; case Regression: d2 = ((RegressionModelPrediction) p).value; break; case Binomial: BinomialModelPrediction bmp = (BinomialModelPrediction) p; d2 = (col == 0) ? bmp.labelIndex : bmp.classProbabilities[col - 1]; break; case Multinomial: MultinomialModelPrediction mmp = (MultinomialModelPrediction) p; d2 = (col == 0) ? mmp.labelIndex : mmp.classProbabilities[col - 1]; break; case DimReduction: d2 = ((DimReductionModelPrediction) p).dimensions[col]; break; } expected_preds[col] = d; actual_preds[col] = d2; } // Verify the correctness of the prediction num_total++; for (int col = genmodel.isClassifier() ? 1 : 0; col < pvecs.length; col++) { if (!MathUtils.compare(actual_preds[col], expected_preds[col], abs_epsilon, rel_epsilon)) { num_errors++; if (num_errors < 20) { System.err.println( (i == 0 ? 
"POJO" : "MOJO") + " EasyPredict Predictions mismatch for row " + rowData); System.err.println(" Expected predictions: " + Arrays.toString(expected_preds)); System.err.println(" Actual predictions: " + Arrays.toString(actual_preds)); } break; } } } } if (num_errors != 0) System.err.println("Number of errors: " + num_errors + (num_errors > 20 ? " (only first 20 are shown)": "") + " out of " + num_total + " rows tested."); return num_errors == 0; } finally { cleanup_adapt(fr, data); // Remove temp keys. } } public void deleteCrossValidationModels( ) { if (_output._cross_validation_models != null) { for (Key k : _output._cross_validation_models) { Model m = DKV.getGet(k); if (m!=null) m.delete(); //delete all subparts } } } @Override public String toString() { return _output.toString(); } /** Model stream writer - output Java code representation of model. */ public class JavaModelStreamWriter extends StreamWriter { /** Show only preview */ private final boolean preview; public JavaModelStreamWriter(boolean preview) { this.preview = preview; } @Override public void writeTo(OutputStream os) { toJava(os, preview, true); } } @Override public Class<KeyV3.ModelKeyV3> makeSchema() { return KeyV3.ModelKeyV3.class; } public static Frame makeInteractions(Frame fr, boolean valid, InteractionPair[] interactions, boolean useAllFactorLevels, boolean skipMissing, boolean standardize) { Vec anyTrainVec = fr.anyVec(); Vec[] interactionVecs = new Vec[interactions.length]; String[] interactionNames = new String[interactions.length]; int idx = 0; for (InteractionPair ip : interactions) { interactionNames[idx] = fr.name(ip._v1) + "_" + fr.name(ip._v2); InteractionWrappedVec iwv =new InteractionWrappedVec(anyTrainVec.group().addVec(), anyTrainVec._rowLayout, ip._v1Enums, ip._v2Enums, useAllFactorLevels, skipMissing, standardize, fr.vec(ip._v1)._key, fr.vec(ip._v2)._key); // if(!valid) ip.setDomain(iwv.domain()); interactionVecs[idx++] = iwv; } return new Frame(interactionNames, 
interactionVecs); } public static InteractionWrappedVec[] makeInteractions(Frame fr, InteractionPair[] interactions, boolean useAllFactorLevels, boolean skipMissing, boolean standardize) { Vec anyTrainVec = fr.anyVec(); InteractionWrappedVec[] interactionVecs = new InteractionWrappedVec[interactions.length]; int idx = 0; for (InteractionPair ip : interactions) interactionVecs[idx++] = new InteractionWrappedVec(anyTrainVec.group().addVec(), anyTrainVec._rowLayout, ip._v1Enums, ip._v2Enums, useAllFactorLevels, skipMissing, standardize, fr.vec(ip._v1)._key, fr.vec(ip._v2)._key); return interactionVecs; } public static InteractionWrappedVec makeInteraction(Frame fr, InteractionPair ip, boolean useAllFactorLevels, boolean skipMissing, boolean standardize) { Vec anyVec = fr.anyVec(); return new InteractionWrappedVec(anyVec.group().addVec(), anyVec._rowLayout, ip._v1Enums, ip._v2Enums, useAllFactorLevels, skipMissing, standardize, fr.vec(ip._v1)._key, fr.vec(ip._v2)._key); } /** * This class represents a pair of interacting columns plus some additional data * about specific enums to be interacted when the vecs are categorical. The question * naturally arises why not just use something like an ArrayList of int[2] (as is done, * for example, in the Interaction/CreateInteraction classes) and the answer essentially * boils down a desire to specify these specific levels. * * Another difference with the CreateInteractions class: * 1. 
do not interact on NA (someLvl_NA and NA_somLvl are actual NAs)
 *     this does not appear here, but in the InteractionWrappedVec class
 *  TODO: refactor the CreateInteractions to be useful here and in InteractionWrappedVec
 */
public static class InteractionPair extends Iced {
  /** Value matched by {@link #isInteraction(int, InteractionPair[])}; not assigned anywhere inside this class. */
  public int vecIdx;
  private int _v1,_v2;       // indices of the two interacting columns

  private String[] _domain;  // not null for enum-enum interactions
  private String[] _v1Enums; // levels listed for the first column, or null when none were given
  private String[] _v2Enums; // levels listed for the second column, or null when none were given
  private int _hash;         // computed once in the 4-arg constructor, returned by hashCode()

  /** Creates an empty instance; {@link #read(String)} uses one as the parser for each line. */
  private InteractionPair() {}

  /** Stores the pair and precomputes the hash over both column indices and both level lists. */
  private InteractionPair(int v1, int v2, String[] v1Enums, String[] v2Enums) {
    _v1=v1;_v2=v2;_v1Enums=v1Enums;_v2Enums=v2Enums;
    // hash is column ints; Item 9 p.47 of Effective Java
    _hash=17;
    _hash = 31*_hash + _v1;
    _hash = 31*_hash + _v2;
    if( _v1Enums==null ) _hash = 31*_hash;
    else
      for( String s:_v1Enums ) _hash = 31*_hash + s.hashCode();
    if( _v2Enums==null ) _hash = 31*_hash;
    else
      for( String s:_v2Enums ) _hash = 31*_hash + s.hashCode();
  }

  /**
   * Generate all pairwise combinations of ints in the range [from,to).
   * @param from Start index
   * @param to End index (exclusive)
   * @return An array of interaction pairs; empty when {@code from == to}.
   * @throws IllegalArgumentException when the range holds exactly one value or {@code to < from}
   */
  public static InteractionPair[] generatePairwiseInteractions(int from, int to) {
    final int n = to-from;
    // n < 0 used to yield an array of nulls (the size formula is positive for negative n).
    if( n==1 || n<0 )
      throw new IllegalArgumentException("Illegal range of values, must be greater than a single value. Got: " + from + "<" + to);
    InteractionPair[] res = new InteractionPair[ ((n-1)*n) >> 1];  // n*(n-1) / 2
    int idx=0;
    for(int i=from;i<to;++i)
      for(int j=i+1;j<to;++j)
        res[idx++] = new InteractionPair(i,j,null,null);
    return res;
  }

  /**
   * Generate all pairwise combinations of the arguments.
   * @param indexes An array of column indices.
   * @return An array of interaction pairs, or null when {@code indexes} is null or is the single value -1
   * @throws IllegalArgumentException when fewer than two indices are given (other than the single -1)
   */
  public static InteractionPair[] generatePairwiseInteractionsFromList(int... indexes) {
    if( null==indexes ) return null;
    if( indexes.length < 2 ) {
      if( indexes.length==1 && indexes[0]==-1 ) return null;
      throw new IllegalArgumentException("Must supply 2 or more columns.");
    }
    InteractionPair[] res = new InteractionPair[ (indexes.length-1)*(indexes.length)>>1]; // n*(n-1) / 2
    int idx=0;
    for(int i=0;i<indexes.length;++i)
      for(int j=i+1;j<indexes.length;++j)
        res[idx++] = new InteractionPair(indexes[i],indexes[j],null,null);
    return res;
  }

  /**
   * Set the domain; computed in an MRTask over the two categorical vectors that make
   * up this interaction pair
   * @param dom The domain retrieved by the CombineDomainTask in InteractionWrappedVec
   */
  public void setDomain(String[] dom) { _domain=dom; }

  /**
   * Check to see if any of the vecIdx values is the desired value.
   * @param i   value to look for
   * @param ips pairs to search
   * @return position in {@code ips} of the first pair whose {@code vecIdx} equals {@code i}, or -1
   */
  public static int isInteraction(int i, InteractionPair[] ips) {
    int idx = 0;
    for (InteractionPair ip: ips) {
      if (i == ip.vecIdx) return idx;
      else idx++;
    }
    return -1;
  }

  // parser stuff
  private int _p;      // cursor into _str
  private String _str; // the line currently being parsed

  /**
   * Parses a newline-separated list of interaction specifications.
   * Each line has the shape {@code v1[E8,E9]:v2,v3,v8,v90,v128[E1,E22]}.
   *
   * @param interaction the text to parse
   * @return the distinct interaction pairs found
   * @throws IllegalArgumentException when a line is malformed
   */
  public static InteractionPair[] read(String interaction) {
    String[] interactions=interaction.split("\n");
    HashSet<InteractionPair> res = new HashSet<>();
    for (String i: interactions)
      res.addAll(new InteractionPair().parse(i));
    return res.toArray(new InteractionPair[res.size()]);
  }

  /** Parses a single line: the first column is paired with every column that follows the ':'. */
  private HashSet<InteractionPair> parse(String i) { // v1[E8,E9]:v2,v3,v8,v90,v128[E1,E22]
    _p=0;
    _str=i;
    HashSet<InteractionPair> res=new HashSet<>();
    int v1 = parseNum();           // parse the first int
    String[] v1Enums=parseEnums(); // shared
    // Bounds check first, so a line without ':' fails with the intended exception
    // instead of a StringIndexOutOfBoundsException from charAt.
    if( _p>=i.length() || i.charAt(_p)!=':' ) throw new IllegalArgumentException("Error");
    while( _p++<i.length() ) {     // steps over the ':' or the separator before each column
      int v2=parseNum();
      String[] v2Enums=parseEnums();
      if( v1 == v2 ) continue; // don't interact on self!
      res.add(new InteractionPair(v1,v2,v1Enums,v2Enums));
    }
    return res;
  }

  /** Reads an integer at the cursor and advances past it. */
  private int parseNum() {
    // A trailing separator (or an empty line) leaves nothing to read.
    if( _p>=_str.length() )
      throw new IllegalArgumentException("No number could be parsed. Interaction: " + _str);
    int start=_p++; // the first character is taken unconditionally; a non-digit fails in parseInt below
    while( _p<_str.length() && '0' <= _str.charAt(_p) && _str.charAt(_p) <= '9') _p++;
    try {
      return Integer.parseInt(_str.substring(start,_p));
    } catch(NumberFormatException ex) {
      throw new IllegalArgumentException("No number could be parsed. Interaction: " + _str);
    }
  }

  /** Reads an optional bracketed, comma-separated level list at the cursor; returns null when there is none. */
  private String[] parseEnums() {
    final int len = _str.length();
    if( _p>=len || _str.charAt(_p)!='[' ) return null;
    ArrayList<String> enums = new ArrayList<>();
    // Each pass consumes one delimiter: the opening '[', a separating ',' or the closing ']'.
    while( _str.charAt(_p++)!=']' ) {
      int start=_p++; // the first character of a level is taken unconditionally
      while( _p<len && _str.charAt(_p)!=',' && _str.charAt(_p)!=']' ) _p++;
      if( _p>=len )
        throw new IllegalArgumentException("Unterminated enum list. Interaction: " + _str);
      enums.add(_str.substring(start,_p));
    }
    return enums.toArray(new String[enums.size()]);
  }

  @Override public int hashCode() { return _hash; }

  @Override public String toString() {
    return _v1+(_v1Enums==null?"":Arrays.toString(_v1Enums))+":"+_v2+(_v2Enums==null?"":Arrays.toString(_v2Enums));
  }

  /** Two pairs are equal when both column indices and both level lists match. */
  @Override public boolean equals( Object o ) {
    if( !(o instanceof InteractionPair) ) return false;
    InteractionPair ip = (InteractionPair) o;
    return (_v1 == ip._v1) && (_v2 == ip._v2)
        && Arrays.equals(_v1Enums, ip._v1Enums)
        && Arrays.equals(_v2Enums, ip._v2Enums);
  }
}
}