package hex.mli.loco; import hex.Model; import hex.ModelCategory; import water.MRTask; import water.ParallelizationTask; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Chunk; import water.fvec.Frame; import water.fvec.NewChunk; import water.fvec.Vec; import water.DKV; import water.H2O; import water.util.Log; import water.Iced; import water.Job; import water.rapids.ast.prims.reducers.AstMedian; import hex.quantile.QuantileModel.CombineMethod; import water.Key; /** * Leave One Covariate Out (LOCO) * * Calculates row-wise variable importance's by re-scoring a trained supervised model and measuring the impact of setting * each variable to missing or it’s most central value(mean or median & mode for categorical's) * */ public class LeaveOneCovarOut extends Iced { /** * Conduct Leave One Covariate Out (LOCO) given a model, frame, job, and replacement value * @param m Supervised H2O model * @param fr H2O Frame to score * @param job Job to keep track of in terms of progress * @param replaceVal Value to replace column by when conducting LOCO ("mean" or "median"). Default behavior is to setting to NA * @param frameKey Key of final Leave One Covariate Out(LOCO) Frame. * @return An H2OFrame displaying the base prediction (model scored with all predictors) and the difference in predictions * when variables are dropped/replaced. The difference displayed is the base prediction substracted from * the new prediction (when a variable is dropped/replaced with mean/median/mode) for binomial classification * and regression problems. For multinomial problems, the sum of the absolute value of differences across classes * is calculated per column dropped/replaced. */ public static Frame leaveOneCovarOut(Model m, Frame fr, Job job, String replaceVal, Key frameKey){ Frame locoAnalysisFrame = new Frame(); //Set up initial LOCO frame if(m._output.getModelCategory() != ModelCategory.Multinomial) { locoAnalysisFrame.add("base_pred", getBasepredictions(m, fr)[0]); //If not multinomial, then predictions are one column } else { locoAnalysisFrame.add(new Frame(getBasepredictions(m, fr))); locoAnalysisFrame._names[0] = "base_pred"; } String[] predictors = m._output._names; //Get predictors LeaveOneCovariateOutDriver[] tasks = new LeaveOneCovariateOutDriver[predictors.length-1]; //Set up tasks. Last column is not needed as its the response for(int i = 0; i < tasks.length; i++){ tasks[i] = new LeaveOneCovariateOutDriver(locoAnalysisFrame,fr,m,predictors[i],replaceVal); } ParallelizationTask locoCollector = new ParallelizationTask<>(tasks, job); long start = System.currentTimeMillis(); Log.info("Starting Leave One Covariate Out (LOCO) analysis for model " + m._key + " and frame " + fr._key); H2O.submitTask(locoCollector).join(); //If multinomial, then we need to remove predicted probabilities for each class. We only want the final class predicted as the first column if(m._output.getModelCategory() == ModelCategory.Multinomial){ int[] colsToRemove = new int[locoAnalysisFrame.numCols()-1]; for(int i =0; i<colsToRemove.length; i++){ colsToRemove[i] = i+1; } locoAnalysisFrame.remove(colsToRemove); } for (int i = 0; i < tasks.length; i++) { locoAnalysisFrame.add("rc_" + tasks[i]._predictor, tasks[i]._result[0]); } Log.info("Finished Leave One Covariate Out (LOCO) analysis for model " + m._key + " and frame " + fr._key + " in " + (System.currentTimeMillis()-start)/1000. + " seconds for " + (predictors.length-1) + " columns"); //Put final frame into DKV if(frameKey != null) { locoAnalysisFrame._key = frameKey; DKV.put(locoAnalysisFrame._key, locoAnalysisFrame); } else{ locoAnalysisFrame._key = Key.make("loco_"+fr._key.toString() + "_" + m._key.toString()); DKV.put(locoAnalysisFrame._key,locoAnalysisFrame); } return locoAnalysisFrame; } public static class LeaveOneCovariateOutDriver extends H2O.H2OCountedCompleter<LeaveOneCovariateOutDriver>{ private final Frame _locoFrame; private final Frame _frame; private final Model _model; private final String _predictor; private final String _replaceVal; Vec[] _result; public LeaveOneCovariateOutDriver(Frame locoFrame, Frame fr, Model m, String predictor, String replaceVal){ _locoFrame = locoFrame; _frame = fr; _model = m; _predictor = predictor; _replaceVal = replaceVal; } @Override public void compute2() { if(_model._output.getModelCategory() == ModelCategory.Multinomial){ Vec[] predTmp = getNewPredictions(_model,_frame,_predictor,_replaceVal); Frame tmpFrame = new Frame().add(_locoFrame).add(new Frame(predTmp)); _result = new MultiDiffTask(_model._output.nclasses()).doAll(Vec.T_NUM, tmpFrame).outputFrame().vecs(); for (Vec v : predTmp) v.remove(); //Clean up DKV, otherwise we will get leaked keys } else { _result = getNewPredictions(_model, _frame, _predictor,_replaceVal); new DiffTask().doAll(_locoFrame.vec(0), _result[0]); } Log.info("Completed Leave One Covariate Out (LOCO) Analysis for column: " + _predictor); tryComplete(); } } /** * Get base predictions given a model and frame. "Base predictions" are predictions based on all features in the * model * @param m An H2O supervised model * @param fr A Frame to score on * @return An array of Vecs containing predictions */ private static Vec[] getBasepredictions(Model m, Frame fr){ Frame basePredsFr = m.score(fr,null,null,false); if(m._output.getModelCategory() == ModelCategory.Binomial) { Vec basePreds = basePredsFr.remove(2); basePredsFr.delete(); return new Vec[] {basePreds}; }else if(m._output.getModelCategory() == ModelCategory.Multinomial){ Vec[] basePredsVecs = basePredsFr.vecs(); DKV.remove(basePredsFr._key); return basePredsVecs; } else { Vec basePreds = basePredsFr.remove(0); basePredsFr.delete(); return new Vec[] {basePreds}; } } /** * Get new predictions based on dropping/replacing a column with mean, median, or mode * @param m An H2O supervised model * @param fr A Frame to score on * @param colToDrop Column to modify/drop before prediction * @param valToReplace Value to replace colToDrop by (Default is null) * @return */ private static Vec[] getNewPredictions(Model m, Frame fr, String colToDrop, String valToReplace) { Frame workerFrame = new Frame(fr); Vec vecToReplace = fr.vec(colToDrop); Vec replacementVec = null; if(valToReplace == null){ replacementVec = vecToReplace.makeCon(Double.NaN); } else if(valToReplace.equals("mean")){ if(vecToReplace.isCategorical()){ Vec tmpVec = vecToReplace.makeCon(vecToReplace.mode()); replacementVec = tmpVec.toCategoricalVec(); //Can only get mode for categoricals tmpVec.remove(); } else { replacementVec = vecToReplace.makeCon(vecToReplace.mean()); } } else if(valToReplace.equals("median")){ if(vecToReplace.isCategorical()){ Vec tmpVec = vecToReplace.makeCon(vecToReplace.mode()); replacementVec = tmpVec.toCategoricalVec(); //Can only get mode for categoricals tmpVec.remove(); } else { Frame tmpFr = new Frame(vecToReplace); double median = AstMedian.median(tmpFr, CombineMethod.AVERAGE); replacementVec = vecToReplace.makeCon(median); } } else { throw new H2OIllegalArgumentException("Invalid value to replace columns in LOCO. Got " + valToReplace); } int vecToDropIdx = fr.find(colToDrop); workerFrame.replace(vecToDropIdx,replacementVec); DKV.put(workerFrame); Frame modifiedPredictionsFr = m.score(workerFrame,null,null,false); try { if (m._output.getModelCategory() == ModelCategory.Binomial) { Vec modifiedPrediction = modifiedPredictionsFr.remove(2); modifiedPredictionsFr.delete(); return new Vec[] {modifiedPrediction}; } else if(m._output.getModelCategory() == ModelCategory.Multinomial){ Vec[] vecs = modifiedPredictionsFr.vecs(); DKV.remove(modifiedPredictionsFr._key); return vecs; } else { Vec modifiedPrediction = modifiedPredictionsFr.remove(0); modifiedPredictionsFr.delete(); return new Vec[] {modifiedPrediction}; } } finally{ DKV.remove(workerFrame._key); replacementVec.remove(); } } private static class DiffTask extends MRTask<DiffTask>{ @Override public void map(Chunk[] c) { Chunk _basePred = c[0]; for(int chnk = 1; chnk < c.length; chnk++){ for(int row = 0; row < c[0]._len; row++){ c[chnk].set(row, c[chnk].atd(row) - _basePred.atd(row)); } } } } private static class MultiDiffTask extends MRTask<MultiDiffTask>{ private final int _numClasses; public MultiDiffTask(int numClasses){ _numClasses = numClasses; } @Override public void map(Chunk[] cs, NewChunk nc) { for (int i = 0; i < cs[0]._len; i++) { double d = 0; for (int j = 1; j < _numClasses+1; j++) { double val = cs[j + _numClasses+1].atd(i) - cs[j].atd(i); d += Math.abs(val); } nc.addNum(d); } } } }