package hex; import jsr166y.CountedCompleter; import water.*; import water.api.schemas3.KeyV3; import water.fvec.Frame; import water.fvec.Vec; import water.util.Log; import water.util.TwoDimTable; import java.util.Arrays; /** * Create a Frame from scratch * If randomize = true, then the frame is filled with Random values. */ public class PartialDependence extends Lockable<PartialDependence> { transient final public Job _job; public Key<Model> _model_id; public Key<Frame> _frame_id; public String[] _cols; public int _nbins = 20; public TwoDimTable[] _partial_dependence_data; //OUTPUT public PartialDependence(Key<PartialDependence> dest) { super(dest); _job = new Job<>(dest, PartialDependence.class.getName(), "PartialDependence"); } public Job<PartialDependence> execImpl() { checkSanityAndFillParams(); delete_and_lock(_job); _frame_id.get().write_lock(_job._key); // Don't lock the model since the act of unlocking at the end would // freshen the DKV version, but the live POJO must survive all the way // to be able to delete the model metrics that got added to it. // Note: All threads doing the scoring call model_id.get() and then // update the _model_metrics only on the temporary live object, not in DKV. // At the end, we call model.remove() and we need those model metrics to be // deleted with it, so we must make sure we keep the live POJO alive. _job.start(new PartialDependenceDriver(), _cols.length); return _job; } private void checkSanityAndFillParams() { if (_cols==null) { Model m = _model_id.get(); if (m==null) throw new IllegalArgumentException("Model not found."); if (!m._output.isSupervised() || m._output.nclasses() > 2) throw new IllegalArgumentException("Partial dependence plots are only implemented for regression and binomial classification models"); Frame f = _frame_id.get(); if (f==null) throw new IllegalArgumentException("Frame not found."); if (Model.GetMostImportantFeatures.class.isAssignableFrom(m.getClass())) { _cols = ((Model.GetMostImportantFeatures)m).getMostImportantFeatures(10); if (_cols != null) { Log.info("Selecting the top " + _cols.length + " features from the model's variable importances"); } } } if (_nbins < 2) { throw new IllegalArgumentException("_nbins must be >=2."); } final Frame fr = _frame_id.get(); for (int i = 0; i < _cols.length; ++i) { final String col = _cols[i]; Vec v = fr.vec(col); if (v.isCategorical() && v.cardinality() > _nbins) { throw new IllegalArgumentException("Column " + col + "'s cardinality of " + v.cardinality() + " > nbins of " + _nbins); } } } private class PartialDependenceDriver extends H2O.H2OCountedCompleter<PartialDependenceDriver> { public void compute2() { assert (_job != null); final Frame fr = _frame_id.get(); // loop over PDPs (columns) _partial_dependence_data = new TwoDimTable[_cols.length]; for (int i = 0; i < _cols.length; ++i) { final String col = _cols[i]; Log.debug("Computing partial dependence of model on '" + col + "'."); Vec v = fr.vec(col); int actualbins = _nbins; if (v.isInt() && (v.max() - v.min() + 1) < _nbins) { actualbins = (int) (v.max() - v.min() + 1); } double[] colVals = new double[actualbins]; double delta = (v.max() - v.min()) / (actualbins - 1); if (actualbins == 1) delta = 0; for (int j = 0; j < colVals.length; ++j) { colVals[j] = v.min() + j * delta; } Log.debug("Computing PartialDependence for column " + col + " at the following values: "); Log.debug(Arrays.toString(colVals)); Futures fs = new Futures(); final double meanResponse[] = new double[colVals.length]; final double stddevResponse[] = new double[colVals.length]; final boolean cat = fr.vec(col).isCategorical(); // loop over column values (fill one PartialDependence) for (int k = 0; k < colVals.length; ++k) { final double value = colVals[k]; final int which = k; H2O.H2OCountedCompleter pdp = new H2O.H2OCountedCompleter() { @Override public void compute2() { Frame fr = _frame_id.get(); Frame test = new Frame(fr.names(), fr.vecs()); Vec orig = test.remove(col); Vec cons = orig.makeCon(value); if (cat) cons.setDomain(fr.vec(col).domain()); test.add(col, cons); Frame preds = null; try { preds = _model_id.get().score(test, Key.make().toString(), _job, false); if (_model_id.get()._output.nclasses() == 2) { meanResponse[which] = preds.vec(2).mean(); stddevResponse[which] = preds.vec(2).sigma(); } else if (_model_id.get()._output.nclasses() == 1) { meanResponse[which] = preds.vec(0).mean(); stddevResponse[which] = preds.vec(0).sigma(); } else throw H2O.unimpl(); } finally { if (preds != null) preds.remove(); } cons.remove(); tryComplete(); } }; fs.add(H2O.submitTask(pdp)); } fs.blockForPending(); /* // baseline double baselineMeanResponse; Frame preds = null; try { preds = _model_id.get().score(_frame_id.get()); if (_model_id.get()._output.nclasses() == 2) { baselineMeanResponse = preds.vec(2).mean(); } else if (_model_id.get()._output.nclasses() == 1) { baselineMeanResponse = preds.vec(0).mean(); } else throw H2O.unimpl(); } finally { if (preds!=null) preds.remove(); } */ // Log.info("Baseline: " + baselineMeanResponse); // Log.info(Arrays.toString(meanResponse)); _partial_dependence_data[i] = new TwoDimTable("PartialDependence", ("Partial Dependence Plot of model " + _model_id + " on column '" + _cols[i] + "'"), new String[actualbins], new String[]{_cols[i], "mean_response", "stddev_response"}, new String[]{cat ? "string" : "double", "double", "double"}, new String[]{cat ? "%s" : "%5f", "%5f", "%5f"}, null); for (int j = 0; j < meanResponse.length; ++j) { if (fr.vec(col).isCategorical()) { _partial_dependence_data[i].set(j, 0, fr.vec(col).domain()[(int) colVals[j]]); } else { _partial_dependence_data[i].set(j, 0, colVals[j]); } _partial_dependence_data[i].set(j, 1, meanResponse[j]); _partial_dependence_data[i].set(j, 2, stddevResponse[j]); } _job.update(1); update(_job); if (_job.stop_requested()) break; } tryComplete(); } @Override public void onCompletion(CountedCompleter caller) { _frame_id.get().unlock(_job._key); unlock(_job); } @Override public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) { _frame_id.get().unlock(_job._key); unlock(_job); return true; } } @Override public Class<KeyV3.PartialDependenceKeyV3> makeSchema() { return KeyV3.PartialDependenceKeyV3.class; } }