package hex.singlenoderf;

import dontweave.gson.JsonArray;
import dontweave.gson.JsonElement;
import dontweave.gson.JsonObject;
import dontweave.gson.JsonPrimitive;
import hex.ConfusionMatrix;
import hex.VarImp;
import hex.gbm.DTree;
import hex.gbm.DTree.TreeModel.TreeStats;
import water.*;
import water.api.*;
import water.api.Request.API;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.Counter;
import water.util.ModelUtils;

import java.util.Arrays;
import java.util.Random;

import static hex.singlenoderf.VariableImportance.asVotes;

public class SpeeDRFModel extends Model implements Job.Progress {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  /** Model Parameters */
  /* Number of features these trees are built for */
  int features;
  /* Sampling strategy used for model */
  Sampling.Strategy sampling_strategy;
  /* Key name */
  String src_key;
  @API(help = "Sampling rate used when building trees.")
  float sample;
  @API(help = "Strata sampling rate used for local-node strata-sampling")
  float[] strata_samples;
  @API(help = "Number of split features defined by user.")
  int mtry;
  /* Number of computed split features per node */
  int[] node_split_features;
  @API(help = "Number of keys the model expects to be built for it.")
  int N;
  @API(help = "Max depth to grow trees to")
  int max_depth;
  @API(help = "All the trees in the model.")
  Key[] t_keys;
  /* Local forests produced by nodes */
  Key[][] local_forests;
  /* Errors per tree */
  long[] errorsPerTree;
  /* Total time in seconds to produce the model */
  long time;
  /* Is there a validation set? */
  boolean validation;
  /* Response min */
  int resp_min;
  /* Class weights */
  double[] weights;
  @API(help = "bin limit")
  int nbins;
  /* Raw tree data, for faster classification passes */
  transient byte[][] trees;
  @API(help = "Job key")
  Key jobKey;
  /* Destination key */
  Key dest_key;
  /* Current model status */
  String current_status;
  @API(help = "MSE by tree")
  double[] errs;
  /* Statistic type */
  Tree.StatType statType;
  @API(help = "Test Key")
  Key testKey;
  /* Out-of-bag error estimate */
  boolean oobee;
  /* Seed */
  protected long zeed;
  /* Variable importance */
  boolean importance;
  /* Final confusion matrix */
  CMTask.CMFinal confusion;
  @API(help = "Confusion Matrices")
  ConfusionMatrix[] cms;
  /* Confusion matrix */
  long[][] cm;
  @API(help = "Tree Statistics")
  TreeStats treeStats;
  @API(help = "cmDomain")
  String[] cmDomain;
  @API(help = "AUC")
  public AUCData validAUC;
  @API(help = "Variable Importance")
  public VarImp varimp;
  /* Regression or classification? */
  boolean regression;
  /* Score each iteration? */
  boolean score_each;
  @API(help = "CV Error")
  public double cv_error;
  @API(help = "Verbose Mode")
  public boolean verbose;
  @API(help = "Verbose Output")
  public String[] verbose_output;
  @API(help = "Use non-local data")
  public boolean useNonLocal;
  @API(help = "Dtree keys")
  public Key[/*ntree*/][/*nclass*/] dtreeKeys;
  @API(help = "DTree Model")
  public SpeeDRFModel_DTree dtreeTreeModel = null;
  @API(help = "score_pojo boolean")
  public boolean score_pojo;

  private float _ss;
  private float _cnt;
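  // Note: t_keys, errs, and cms are maintained as parallel per-tree arrays --
  // make() below extends all of them by one slot for each tree added, with
  // errs seeded to -1.0 and cms to null until the tree is scored.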
  /** Extra helper variables. */
  // Tree votes/SSE per individual feature on permuted OOB rows
  private transient VariableImportance.TreeMeasures[/*features*/] _treeMeasuresOnOOB;
  private transient VariableImportance.TreeMeasures[/*features*/] _treeMeasuresOnSOOB;

  public static final String JSON_CONFUSION_KEY     = "confusion_key";
  public static final String JSON_CM_TYPE           = "type";
  public static final String JSON_CM_HEADER         = "header";
  public static final String JSON_CM_MATRIX         = "scores";
  public static final String JSON_CM_TREES          = "used_trees";
  public static final String JSON_CM_CLASS_ERR      = "classification_error";
  public static final String JSON_CM_ROWS           = "rows";
  public static final String JSON_CM_ROWS_SKIPPED   = "rows_skipped";
  public static final String JSON_CM_CLASSES_ERRORS = "classes_errors";

  @API(help = "Model parameters", json = true)
  private final SpeeDRF parameters;

  @Override public final SpeeDRF get_params() { return parameters; }
  @Override public final Request2 job() { return get_params(); }
  @Override public final VarImp varimp() { return varimp; }
  public float[] priordist() { return _priorClassDist; }
  public float[] modeldist() { return _modelClassDist; }

  public SpeeDRFModel(Key selfKey, Key dataKey, Frame fr, String[] domain, SpeeDRF params, float[] priorDist) {
    super(selfKey, dataKey, fr, priorDist);
    this.dest_key = selfKey;
    this.parameters = params;
    score_each = params.score_each_iteration;
    regression = !(params.classification);
  }

  protected SpeeDRFModel(SpeeDRFModel model, double err, ConfusionMatrix cm, VarImp varimp, AUCData auc) {
    super(model._key, model._dataKey, model._names, model._domains,
          model._priorClassDist, model._modelClassDist, model.training_start_time, model.training_duration_in_ms);
    this.features = model.features;
    this.sampling_strategy = model.sampling_strategy;
    this.sample = model.sample;
    this.strata_samples = model.strata_samples;
    this.mtry = model.mtry;
    this.node_split_features = model.node_split_features;
    this.N = model.N;
    this.max_depth = model.max_depth;
    this.t_keys = model.t_keys;
    this.local_forests = model.local_forests;
    this.time = model.time;
    this.weights = model.weights;
    this.nbins = model.nbins;
    this.trees = model.trees;
    this.jobKey = model.jobKey;
    this.dest_key = model.dest_key;
    this.current_status = model.current_status;
    this.errs = model.errs;
    this.statType = model.statType;
    this.testKey = model.testKey;
    this.oobee = model.oobee;
    this.zeed = model.zeed;
    this.importance = model.importance;
    this.confusion = model.confusion;
    this.cms = Arrays.copyOf(model.cms, model.cms.length + 1);
    this.cms[this.cms.length - 1] = cm;
    this.parameters = model.parameters;
    this.cm = cm._arr;
    this.treeStats = model.treeStats;
    this.cmDomain = model.cmDomain;
    this.validAUC = auc;
    this.varimp = varimp;
    this.regression = model.regression;
    this.score_each = model.score_each;
    this.cv_error = err;
    this.verbose = model.verbose;
    this.verbose_output = model.verbose_output;
    this.useNonLocal = model.useNonLocal;
    this.errorsPerTree = model.errorsPerTree;
    this.resp_min = model.resp_min;
    this.validation = model.validation;
    this.src_key = model.src_key;
    this.score_pojo = model.score_pojo;
  }

  public int treeCount() { return t_keys.length; }
  public int size()      { return t_keys.length; }
  public int classes()   { return nclasses(); }

  @Override public ConfusionMatrix cm() {
    return validAUC == null ? cms[cms.length - 1] : validAUC.CM();
  }
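  // Scoring is dispatched by scoreAllTrees() below: models built with a test
  // set score on it via scoreOnTest(); otherwise scoreOnTrain() scores the
  // training frame (out-of-bag when oobee is set). A rough caller sketch,
  // with hypothetical fr/resp locals:
  //
  //   Frame fr = ...;                 // frame with the response as last column
  //   Vec resp = fr.lastVec();
  //   model.scoreAllTrees(fr, resp);  // fills errs/cms (and validAUC for binomial)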
  private void scoreOnTest(Frame fr, Vec modelResp) {
    Frame scored = score(fr);
    water.api.ConfusionMatrix cm = new water.api.ConfusionMatrix();
    cm.vactual = fr.lastVec();
    cm.vpredict = scored.anyVec();
    cm.invoke();

    if (regression) {
      // Regression scoring
      float mse = (float) cm.mse;
      errs[errs.length - 1] = mse;
      cms[cms.length - 1] = null;
    } else {
      // Classification scoring
      Vec lv = scored.lastVec();
      double mse = CMTask.MSETask.doTask(scored.add("actual", fr.lastVec()));
      this.cm = cm.cm;
      errs[errs.length - 1] = (float) mse;
      ConfusionMatrix new_cm = new ConfusionMatrix(this.cm);
      cms[cms.length - 1] = new_cm;
      // Create the ROC plot
      if (classes() == 2) {
        Vec v = null;
        Frame fa = null;
        if (lv.isInt()) {
          fa = new MRTask2() {
            @Override public void map(Chunk[] cs, NewChunk nchk) {
              int rows = cs[0]._len;
              int cols = cs.length - 1;
              for (int r = 0; r < rows; ++r)
                nchk.addNum(cs[cols].at0(r) == 0 ? 1e-10 : 1.0 - 1e-10);
            }
          }.doAll(1, scored).outputFrame(null, null);
          v = fa.anyVec();
        }
        AUC auc_calc = new AUC();
        auc_calc.vactual = cm.vactual;
        auc_calc.vpredict = v == null ? lv : v; // lastVec is class1
        auc_calc.invoke();
        validAUC = auc_calc.data();
        if (v != null) UKV.remove(v._key);
        if (fa != null) fa.delete();
        UKV.remove(lv._key);
      }
    }
    scored.remove("actual");
    scored.delete();
  }

  private void scoreOnTrain(Frame fr, Vec modelResp) {
    final CMTask cmTask = CMTask.scoreTask(this, treeCount(), oobee, fr, modelResp);
    if (regression) {
      float mse = cmTask._ss / ((float) cmTask._rowcnt);
      errs[errs.length - 1] = mse;
      cms[cms.length - 1] = null;
    } else {
      confusion = CMTask.CMFinal.make(cmTask._matrix, this, classNames(), cmTask._errorsPerTree, oobee, cmTask._sum, cmTask._cms);
      this.cm = cmTask._matrix._matrix;
      errorsPerTree = cmTask._errorsPerTree;
      errs[errs.length - 1] = confusion.mse();
      cms[cms.length - 1] = new ConfusionMatrix(confusion._matrix);
      if (classes() == 2)
        validAUC = makeAUC(toCMArray(confusion._cms), ModelUtils.DEFAULT_THRESHOLDS, cmDomain);
    }
  }

  void scoreAllTrees(Frame fr, Vec modelResp) {
    if (this.validation) scoreOnTest(fr, modelResp);
    else scoreOnTrain(fr, modelResp);
  }

  void variableImportanceCalc(Frame fr, Vec modelResp) {
    varimp = doVarImpCalc(fr, this, modelResp);
  }

  public static SpeeDRFModel make(SpeeDRFModel old, Key tkey, Key dtKey, int nodeIdx, String tString, int tree_id) {
    // Create a new model for atomic update
    SpeeDRFModel m = (SpeeDRFModel) old.clone();
    // Update the tree keys with the new one (tkey)
    m.t_keys = Arrays.copyOf(old.t_keys, old.t_keys.length + 1);
    m.t_keys[m.t_keys.length - 1] = tkey;
    // Update the dtree keys with the new one (dtKey)
    m.dtreeKeys[tree_id][0] = dtKey;
    // Update the local forests
    m.local_forests[nodeIdx][tree_id] = tkey;
    // Update the tree strings?
    if (old.verbose_output.length < 2) {
      m.verbose_output = Arrays.copyOf(old.verbose_output, old.verbose_output.length + 1);
      m.verbose_output[m.verbose_output.length - 1] = tString;
    }
    m.errs = Arrays.copyOf(old.errs, old.errs.length + 1);
    m.errs[m.errs.length - 1] = -1.0;
    m.cms = Arrays.copyOf(old.cms, old.cms.length + 1);
    m.cms[m.cms.length - 1] = null;
    return m;
  }
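  // make() implements a copy-on-write update: the model is cloned, the
  // per-tree arrays are grown by one, and the fresh copy is presumably
  // published back into the K/V store by the caller (e.g. via DKV.put),
  // so readers never observe a half-updated model.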
  public String name(int atree) {
    if (atree == -1) atree = size();
    assert atree <= size();
    return _key.toString() + "[" + atree + "]";
  }

  /** Return the bits for a particular tree */
  public byte[] tree(int tree_id) {
    byte[][] ts = trees;
    if (ts == null) trees = ts = new byte[tree_id + 1][];
    if (tree_id >= ts.length) trees = ts = Arrays.copyOf(ts, tree_id + 1);
    if (ts[tree_id] == null) ts[tree_id] = DKV.get(t_keys[tree_id]).memOrLoad();
    return ts[tree_id];
  }

  /** Free all internal tree keys. */
  @Override public Futures delete_impl(Futures fs) {
    for (Key k : t_keys) UKV.remove(k, fs);
    for (Key[] ka : local_forests)
      for (Key k : ka)
        if (k != null) UKV.remove(k, fs);
    return fs;
  }

  /**
   * Classify a row according to one particular tree.
   * @param tree_id the number of the tree to use
   * @param chunks the chunk we are using
   * @param row the row number in the chunk
   * @param modelDataMap mapping from model/tree columns to data columns
   * @return the predicted response class, or class+1 for broken rows
   */
  public float classify0(int tree_id, Chunk[] chunks, int row, int[] modelDataMap, short badrow, boolean regression) {
    return Tree.classify(new AutoBuffer(tree(tree_id)), chunks, row, modelDataMap, badrow, regression);
  }

  private void vote(Chunk[] chks, int row, int[] modelDataMap, int[] votes) {
    int numClasses = classes();
    assert votes.length == numClasses + 1 /* +1 to catch broken rows */;
    for (int i = 0; i < treeCount(); i++)
      votes[(int) classify0(i, chks, row, modelDataMap, (short) numClasses, false)]++;
  }

  public short classify(Chunk[] chks, int row, int[] modelDataMap, int[] votes, double[] classWt, Random rand) {
    // Vote all the trees for the row
    vote(chks, row, modelDataMap, votes);
    return classify(votes, classWt, rand);
  }

  public short classify(int[] votes, double[] classWt, Random rand) {
    // Scale the votes by class weights: it is as if rows of the weighted
    // classes were replicated many times, so they get many votes.
    if (classWt != null)
      for (int i = 0; i < votes.length - 1; i++)
        votes[i] = (int) (votes[i] * classWt[i]);
    // Tally results
    int result = 0;
    int tied = 1;
    for (int i = 1; i < votes.length - 1; i++)
      if (votes[i] > votes[result]) { result = i; tied = 1; }
      else if (votes[i] == votes[result]) { tied++; }
    if (tied == 1) return (short) result;
    // Tie-breaker logic
    int j = rand == null ? 0 : rand.nextInt(tied); // From zero to number of tied classes - 1
    int k = 0;
    for (int i = 0; i < votes.length - 1; i++)
      if (votes[i] == votes[result] && (k++ >= j))
        return (short) i;
    throw H2O.unimpl();
  }
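  // Worked example of the tie-break above, assuming 3 classes (votes has one
  // extra trailing slot for broken rows): votes = {5, 5, 2, 0} leaves classes
  // 0 and 1 tied at 5 votes, so tied == 2; rand.nextInt(2) picks j in {0, 1},
  // and the final scan returns the (j+1)-th class whose count equals the max.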
  // The seed for a given tree
  long seed(int ntree) { return UDP.get8(tree(ntree), 4); }
  // The producer for a given tree
  byte producerId(int ntree) { return tree(ntree)[12]; }

  // Lazy initialization of tree leaves, depth
  private transient Counter _tl, _td;

  /** Internal computation of depth and number of leaves. */
  public void find_leaves_depth() {
    // if( _tl != null ) return;
    _td = new Counter();
    _tl = new Counter();
    for (Key tkey : t_keys) {
      long dl = Tree.depth_leaves(new AutoBuffer(DKV.get(tkey).memOrLoad()), regression);
      _td.add((int) (dl >> 32)); // depth is packed into the high 32 bits
      _tl.add((int) dl);         // leaf count in the low 32 bits
    }
  }

  public Counter leaves() { find_leaves_depth(); return _tl; }
  public Counter depth()  { find_leaves_depth(); return _td; }

  private static int find(String n, String[] names) {
    if (n == null) return -1;
    for (int j = 0; j < names.length; j++)
      if (n.equals(names[j])) return j;
    return -1;
  }

  public int[] colMap(Frame df) {
    int[] res = new int[df._names.length]; // new int[names.length];
    for (int i = 0; i < res.length; i++)
      res[i] = find(df.names()[i], _names);
    return res;
  }

  @Override protected float[] score0(double[] data, float[] preds) {
    int numClasses = classes();
    if (numClasses == 1) {
      // Regression: average the per-tree predictions
      float p = 0.f;
      for (int i = 0; i < treeCount(); ++i)
        p += Tree.classify(new AutoBuffer(tree(i)), data, 0.0, true) / (1. * treeCount());
      return new float[]{ p };
    } else {
      int[] votes = new int[numClasses + 1 /* +1 to catch broken rows */];
      preds = new float[numClasses + 1];
      for (int i = 0; i < treeCount(); i++) {
        // DTree.TreeModel.CompressedTree t = UKV.get(dtreeKeys[i][0]);
        votes[(int) Tree.classify(new AutoBuffer(tree(i)), data, numClasses, false)]++;
      }
      float s = 0.f;
      for (int v : votes) s += (float) v;
      if (get_params().balance_classes) {
        for (int i = 0; i < votes.length - 1; ++i)
          preds[i + 1] = (float) votes[i] / treeCount();
        return preds;
      }
      for (int i = 0; i < votes.length - 1; ++i)
        preds[i + 1] = (float) votes[i] / (float) treeCount();
      // preds[0] = (float) (classify(votes, null, null) + resp_min);
      preds[0] = ModelUtils.getPrediction(preds, data);
      float[] rawp = new float[preds.length + 1];
      for (int i = 0; i < votes.length; ++i) rawp[i + 1] = (float) votes[i];
      return preds;
    }
  }
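  // Worked example for the classification branch of score0: with 3 classes,
  // 10 trees, and votes = {2, 5, 3} (ignoring the broken-row slot), preds
  // becomes {class, 0.2, 0.5, 0.3} -- per-class vote fractions in slots
  // 1..nclasses, with the winning label (via ModelUtils.getPrediction) in slot 0.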
  @Override public float progress() {
    return get_params().cv_progress(t_keys.length / (float) N);
  }

  static String[] cfDomain(final CMTask.CMFinal cm, int maxClasses) {
    String[] dom = cm.domain();
    if (dom.length > maxClasses)
      throw new IllegalArgumentException("The column has more than " + maxClasses + " values. Are you sure you have that many classes?");
    return dom;
  }

  private boolean errsNotNull() {
    boolean allMinus1 = true;
    if (errs == null) return false;
    for (double err : errs)
      if (err > -1) allMinus1 = false;
    return !allMinus1;
  }

  public void generateHTML(String title, StringBuilder sb) {
    String style = "<style>\n" +
                   "td, th { min-width:60px;}\n" +
                   "</style>\n";
    sb.append(style);
    DocGen.HTML.title(sb, title);
    sb.append("<div class=\"alert\">").append("Actions: ");
    sb.append(Inspect2.link("Inspect training data (" + _dataKey.toString() + ")", _dataKey)).append(", ");
    if (validation)
      sb.append(Inspect2.link("Inspect testing data (" + testKey.toString() + ")", testKey)).append(", ");
    sb.append(Predict.link(_key, "Score on dataset"));
    if (this.size() > 0 && this.size() < N && !Job.findJob(jobKey).isCancelledOrCrashed()) {
      sb.append(", ");
      sb.append("<i class=\"icon-stop\"></i> ").append(Cancel.link(jobKey, "Cancel training"));
    }
    sb.append("</div>");
    DocGen.HTML.paragraph(sb, "Model Key: " + _key);
    DocGen.HTML.paragraph(sb, "Max depth: " + max_depth + ", Nbins: " + nbins + ", Trees: " + this.size());
    DocGen.HTML.paragraph(sb, "Sample Rate: " + sample + ", User Seed: " + get_params().seed + ", Internal Seed: " + zeed + ", mtry: " + mtry);
    sb.append("</pre>");
    if (this.size() > 0 && this.size() < N) {
      sb.append("Current Status: ").append("Building Random Forest");
    } else if (this.size() == N && !this.current_status.equals("Performing Variable Importance Calculation.")) {
      sb.append("Current Status: ").append("Complete.");
    } else if (Job.findJob(jobKey).isCancelledOrCrashed()) {
      sb.append("Current Status: ").append("Cancelled.");
    } else {
      sb.append("Current Status: ").append(this.current_status);
    }
    if (_have_cv_results) {
      sb.append("<div class=\"alert\">Scoring results reported for ").append(this.parameters.n_folds)
        .append("-fold cross-validated training data ").append(Inspect2.link(_dataKey.toString(), _dataKey)).append("</div>");
    } else if (testKey != null) {
      sb.append("<div class=\"alert\">Reported on ").append(Inspect2.link(testKey.toString(), testKey)).append("</div>");
    } else {
      sb.append("<div class=\"alert\">Reported on ").append(oobee ? "OOB" : "training").append(" data</div>");
    }
    // Build the confusion matrix
    if (!regression) {
      // if (confusion != null && confusion.valid() && (this.N * .25 > 0) && classes() >= 2) {
      //   buildCM(sb);
      // } else {
      if (this.cms[this.cms.length - 1] != null && (this.N * .25 > 0 && classes() >= 2)) {
        this.cms[this.cms.length - 1].toHTML(sb, this.cmDomain);
      }
      // }
    }
    sb.append("<br />");
    if (errsNotNull() && this.size() > 0) {
      DocGen.HTML.section(sb, "Mean Squared Error by Tree");
      DocGen.HTML.arrayHead(sb);
      sb.append("<tr style='min-width:60px'><th>Trees</th>");
      int last = this.size(); // + 1;
      for (int i = last; i >= 0; i--)
        sb.append("<td style='min-width:60px'>").append(i).append("</td>");
      sb.append("</tr>");
      sb.append("<tr style='min-width: 60px;'><th style='min-width: 60px;' class='warning'>MSE</th>");
      for (int i = last; i >= 0; i--)
        sb.append(!(Double.isNaN(errs[i]) || errs[i] <= 0.0)
            ? String.format("<td style='min-width:60px'>%5.5f</td>", errs[i])
            : "<td style='min-width:60px'>---</td>");
      sb.append("</tr>");
      DocGen.HTML.arrayTail(sb);
    }
    sb.append("<br/>");
    JsonObject trees = new JsonObject();
    trees.addProperty(Constants.TREE_COUNT, this.size());
    if (this.size() > 0) {
      trees.add(Constants.TREE_DEPTH,  this.depth().toJson());
      trees.add(Constants.TREE_LEAVES, this.leaves().toJson());
    }
    if (validAUC != null) generateHTMLAUC(sb);
    generateHTMLTreeStats(sb, trees);
    if (varimp != null) generateHTMLVarImp(sb);
    printCrossValidationModelsHTML(sb);
  }
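  // transform2DTreeModel() re-wraps the forest as a DTree.TreeModel. This
  // appears to exist for the POJO export path: the isFromSpeeDRF flag tells
  // the shared toJava code generator that the trees were translated from a
  // SpeeDRF model.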
  public DTree.TreeModel transform2DTreeModel() {
    if (dtreeTreeModel != null) {
      // Freshen the dtreeTreeModel
      dtreeTreeModel = new SpeeDRFModel_DTree(dtreeTreeModel, dtreeKeys, treeStats);
      return dtreeTreeModel;
    }
    Key key = Key.make();
    Key model_key = _key;
    Key dataKey = _dataKey;
    Key testKey = null;
    String[] names = _names;
    String[][] domains = _domains;
    String[] cmDomain = this.cmDomain;
    int ntrees = treeCount();
    int min_rows = 0;
    int nbins = this.nbins;
    int mtries = this.mtry;
    long seed = -1;
    int num_folds = 0;
    float[] priorClassDist = null;
    float[] classDist = null;
    // Dummy model
    dtreeTreeModel = new SpeeDRFModel_DTree(model_key, model_key, dataKey, testKey, names, domains, cmDomain,
        ntrees, max_depth, min_rows, nbins, mtries, num_folds, priorClassDist, classDist);
    // Update the model
    dtreeTreeModel = new SpeeDRFModel_DTree(dtreeTreeModel, dtreeKeys, treeStats);
    dtreeTreeModel.isFromSpeeDRF = true; // tells the toJava method the model is translated from a SpeeDRF model
    return dtreeTreeModel;
  }

  public static class SpeeDRFModel_DTree extends DTree.TreeModel {
    static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
    static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

    Key modelKey;

    public SpeeDRFModel_DTree(Key key, Key modelKey, Key dataKey, Key testKey,
                              String[] names, String[][] domains, String[] cmDomain,
                              int ntrees, int max_depth, int min_rows, int nbins, int mtries,
                              int num_folds, float[] priorClassDist, float[] classDist) {
      super(key, dataKey, testKey, names, domains, cmDomain, ntrees, max_depth, min_rows, nbins, num_folds, priorClassDist, classDist);
      this.modelKey = modelKey;
    }

    public SpeeDRFModel_DTree(SpeeDRFModel_DTree prior, Key[][] treeKeys, TreeStats tstats) {
      super(prior, treeKeys, null, prior.cms, tstats, null, null);
    }

    @Override protected void generateModelDescription(StringBuilder sb) { }
  }
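  // Serialized layout written by postSave() below (and read back by
  // postLoad()): a 4-byte tree count, then for each tree its raw byte blob
  // followed by any non-null per-class compressed DTrees.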
  @Override public ModelAutobufferSerializer getModelSerializer() {
    // Return a serializer which knows how to serialize keys
    return new ModelAutobufferSerializer() {
      @Override protected AutoBuffer postSave(Model m, AutoBuffer ab) {
        int ntrees = N;
        ab.put4(ntrees);
        // Must fill out t_keys and dtreeKeys
        for (int i = 0; i < ntrees; ++i) {
          byte[] bits = tree(i);
          ab.putA1(bits);
          for (int j = 0; j < nclasses(); ++j) {
            if (dtreeKeys[i][j] == null) continue;
            Value v = DKV.get(dtreeKeys[i][j]);
            if (v == null) continue;
            DTree.TreeModel.CompressedTree t = v.get();
            ab.put(t);
          }
        }
        return ab;
      }
      @Override protected AutoBuffer postLoad(Model m, AutoBuffer ab) {
        int ntrees = ab.get4();
        Futures fs = new Futures();
        for (int i = 0; i < ntrees; ++i) {
          DKV.put(t_keys[i], new Value(t_keys[i], ab.getA1()), fs);
          for (int j = 0; j < nclasses(); ++j) {
            if (dtreeKeys[i][j] == null) continue;
            UKV.put(dtreeKeys[i][j], new Value(dtreeKeys[i][j], ab.get(DTree.TreeModel.CompressedTree.class)), fs);
          }
        }
        fs.blockForPending();
        return ab;
      }
    };
  }

  static final String NA = "---";

  public void generateHTMLTreeStats(StringBuilder sb, JsonObject trees) {
    DocGen.HTML.section(sb, "Tree stats");
    DocGen.HTML.arrayHead(sb);
    sb.append("<tr><th> </th>").append("<th>Min</th><th>Mean</th><th>Max</th></tr>");
    TreeStats treeStats = new TreeStats();
    double[] depth_stats = stats(trees.get(Constants.TREE_DEPTH));
    double[] leaf_stats  = stats(trees.get(Constants.TREE_LEAVES));
    sb.append("<tr><th>Depth</th>")
      .append("<td>").append(depth_stats != null ? (int) depth_stats[0] : NA).append("</td>")
      .append("<td>").append(depth_stats != null ? depth_stats[1] : NA).append("</td>")
      .append("<td>").append(depth_stats != null ? (int) depth_stats[2] : NA).append("</td></tr>");
    sb.append("<tr><th>Leaves</th>")
      .append("<td>").append(leaf_stats != null ? (int) leaf_stats[0] : NA).append("</td>")
      .append("<td>").append(leaf_stats != null ? leaf_stats[1] : NA).append("</td>")
      .append("<td>").append(leaf_stats != null ? (int) leaf_stats[2] : NA).append("</td></tr>");
    DocGen.HTML.arrayTail(sb);
    if (depth_stats != null && leaf_stats != null) {
      treeStats.minDepth   = (int) depth_stats[0];
      treeStats.meanDepth  = (float) depth_stats[1];
      treeStats.maxDepth   = (int) depth_stats[2];
      treeStats.minLeaves  = (int) leaf_stats[0];
      treeStats.meanLeaves = (float) leaf_stats[1];
      treeStats.maxLeaves  = (int) leaf_stats[2];
      treeStats.setNumTrees(N);
    } else {
      treeStats = null;
    }
    this.treeStats = treeStats;
  }

  private static double[] stats(JsonElement json) {
    if (json == null) return null;
    JsonObject obj = json.getAsJsonObject();
    return new double[]{
        Math.round(obj.get(Constants.MIN).getAsDouble()  * 1000.0) / 1000.0,
        Math.round(obj.get(Constants.MEAN).getAsDouble() * 1000.0) / 1000.0,
        Math.round(obj.get(Constants.MAX).getAsDouble()  * 1000.0) / 1000.0 };
  }
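  // buildCM() reports progress in 25% increments: modelSize is a quarter of
  // the requested trees, rounded down to the last completed quarter. For
  // example, with tasks = 50 and finished = 27: modelSize = 12, and
  // 12 * (27 / 12) rounds the displayed tree count down to 24.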
  public void buildCM(StringBuilder sb) {
    int tasks = this.N;
    int finished = this.size();
    int modelSize = tasks * 25 / 100;
    modelSize = modelSize == 0 || finished == tasks ? finished : modelSize * (finished / modelSize);
    if (confusion != null && confusion.valid() && modelSize > 0) {
      // finished += 1;
      JsonObject cm = new JsonObject();
      JsonArray cmHeader = new JsonArray();
      JsonArray matrix = new JsonArray();
      cm.addProperty(JSON_CM_TYPE, oobee ? "OOB" : "training");
      cm.addProperty(JSON_CM_CLASS_ERR, confusion.classError());
      cm.addProperty(JSON_CM_ROWS_SKIPPED, confusion.skippedRows());
      cm.addProperty(JSON_CM_ROWS, confusion.rows());
      // Create the header
      for (String s : cfDomain(confusion, 1024))
        cmHeader.add(new JsonPrimitive(s));
      cm.add(JSON_CM_HEADER, cmHeader);
      // Add the matrix
      final int nclasses = confusion.dimension();
      JsonArray classErrors = new JsonArray();
      for (int crow = 0; crow < nclasses; ++crow) {
        JsonArray row = new JsonArray();
        int classHitScore = 0;
        for (int ccol = 0; ccol < nclasses; ++ccol) {
          row.add(new JsonPrimitive(confusion.matrix(crow, ccol)));
          if (crow != ccol) classHitScore += confusion.matrix(crow, ccol);
        }
        // Produce infinity members in case of 0.f/0
        classErrors.add(new JsonPrimitive((float) classHitScore / (classHitScore + confusion.matrix(crow, crow))));
        matrix.add(row);
      }
      cm.add(JSON_CM_CLASSES_ERRORS, classErrors);
      cm.add(JSON_CM_MATRIX, matrix);
      cm.addProperty(JSON_CM_TREES, modelSize);
      // Signal the end if and only if all trees were generated and the confusion matrix is valid
      DocGen.HTML.section(sb, "Confusion Matrix:");
      if (cm.has(JSON_CM_MATRIX)) {
        sb.append("<dl class='dl-horizontal'>");
        sb.append("<dt>classification error</dt><dd>").append(String.format("%5.5f %%", 100 * cm.get(JSON_CM_CLASS_ERR).getAsFloat())).append("</dd>");
        long rows = cm.get(JSON_CM_ROWS).getAsLong();
        long skippedRows = cm.get(JSON_CM_ROWS_SKIPPED).getAsLong();
        sb.append("<dt>used / skipped rows </dt><dd>").append(String.format("%d / %d (%3.1f %%)", rows, skippedRows, (double) skippedRows * 100 / (skippedRows + rows))).append("</dd>");
        sb.append("<dt>trees used</dt><dd>").append(cm.get(JSON_CM_TREES).getAsInt()).append("</dd>");
        sb.append("</dl>");
        sb.append("<table class='table table-striped table-bordered table-condensed'>");
        sb.append("<tr style='min-width: 60px;'><th style='min-width: 60px;'>Actual \\ Predicted</th>");
        JsonArray header = (JsonArray) cm.get(JSON_CM_HEADER);
        for (JsonElement e : header)
          sb.append("<th style='min-width: 60px;'>").append(e.getAsString()).append("</th>");
        sb.append("<th style='min-width: 60px;'>Error</th></tr>");
        int classes = header.size();
        long[] totals = new long[classes];
        JsonArray matrix2 = (JsonArray) cm.get(JSON_CM_MATRIX);
        long sumTotal = 0;
        long sumError = 0;
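        // Render one table row per actual class: the diagonal cell is
        // highlighted, off-diagonal counts accumulate into that class's error,
        // and per-column totals feed the final "Totals" row.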
        for (int crow = 0; crow < classes; ++crow) {
          JsonArray row = (JsonArray) matrix2.get(crow);
          long total = 0;
          long error = 0;
          sb.append("<tr style='min-width: 60px;'><th style='min-width: 60px;'>").append(header.get(crow).getAsString()).append("</th>");
          for (int ccol = 0; ccol < classes; ++ccol) {
            long num = row.get(ccol).getAsLong();
            total += num;
            totals[ccol] += num;
            if (ccol == crow) {
              sb.append("<td style='background-color:LightGreen; min-width: 60px;'>");
            } else {
              sb.append("<td style='min-width: 60px;'>");
              error += num;
            }
            sb.append(num);
            sb.append("</td>");
          }
          sb.append("<td style='min-width: 60px;'>");
          sb.append(String.format("%.05f = %,d / %d", (double) error / total, error, total));
          sb.append("</td></tr>");
          sumTotal += total;
          sumError += error;
        }
        sb.append("<tr style='min-width: 60px;'><th style='min-width: 60px;'>Totals</th>");
        for (long total : totals)
          sb.append("<td style='min-width: 60px;'>").append(total).append("</td>");
        sb.append("<td style='min-width: 60px;'><b>");
        sb.append(String.format("%.05f = %,d / %d", (double) sumError / sumTotal, sumError, sumTotal));
        sb.append("</b></td></tr>");
        sb.append("</table>");
      } else {
        sb.append("<div class='alert alert-info'>");
        sb.append("Confusion matrix is being computed into the key:<br/>");
        sb.append(cm.get(JSON_CONFUSION_KEY).getAsString());
        sb.append("</div>");
      }
    }
  }

  private static ConfusionMatrix[] toCMArray(long[][][] cms) {
    int n = cms.length;
    ConfusionMatrix[] res = new ConfusionMatrix[n];
    for (int i = 0; i < n; i++) res[i] = new ConfusionMatrix(cms[i]);
    return res;
  }

  protected static AUCData makeAUC(ConfusionMatrix[] cms, float[] threshold, String[] cmDomain) {
    return cms != null ? new AUC(cms, threshold, cmDomain).data() : null;
  }

  protected void generateHTMLAUC(StringBuilder sb) {
    validAUC.toHTML(sb);
  }

  protected void generateHTMLVarImp(StringBuilder sb) {
    if (varimp != null) {
      // Set up variable names for importance
      varimp.setVariables(Arrays.copyOf(_names, _names.length - 1));
      varimp.toHTML(this, sb);
    }
  }

  protected VarImp doVarImpCalc(final Frame fr, final SpeeDRFModel model, final Vec resp) {
    final int ncols = fr.numCols();
    final int trees = treeCount();
    _treeMeasuresOnOOB  = new VariableImportance.TreeVotes[ncols - 1];
    _treeMeasuresOnSOOB = new VariableImportance.TreeVotes[ncols - 1];
    for (int i = 0; i < ncols - 1; i++) _treeMeasuresOnOOB[i]  = new VariableImportance.TreeVotes(trees);
    for (int i = 0; i < ncols - 1; i++) _treeMeasuresOnSOOB[i] = new VariableImportance.TreeVotes(trees);
    Futures fs = new Futures();
    for (int var = 0; var < ncols - 1; var++) {
      final int variable = var;
      H2O.H2OCountedCompleter task4var = new H2O.H2OCountedCompleter() {
        @Override public void compute2() {
          VariableImportance.TreeVotes[] cd = VariableImportance.collectVotes(trees, classes(), fr, ncols - 1, sample, variable, model, resp);
          asVotes(_treeMeasuresOnOOB[variable]).append(cd[0]);
          asVotes(_treeMeasuresOnSOOB[variable]).append(cd[1]);
          tryComplete();
        }
      };
      H2O.submitTask(task4var);
      fs.add(task4var);
    }
    fs.blockForPending();
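    // Mean-decrease-accuracy below: for each feature j and tree t, delta_t is
    // (correct votes on OOB rows - correct votes with column j permuted) / n_t;
    // the importance is the mean of delta_t over all trees, and the reported
    // SD is sqrt((E[delta^2] - mean^2) / T), i.e. the standard error of that mean.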
    // Compute varimp for individual features (_ncols)
    final float[] varimp = new float[ncols - 1]; // output variable importance
    float[] varimpSD     = new float[ncols - 1]; // output variable importance SD
    for (int var = 0; var < ncols - 1; var++) {
      long[] votesOOB  = asVotes(_treeMeasuresOnOOB[var]).votes();
      long[] votesSOOB = asVotes(_treeMeasuresOnSOOB[var]).votes();
      long[] nrows     = asVotes(_treeMeasuresOnOOB[var]).nrows();
      float imp = 0.f;
      float v = 0.f;
      for (int i = 0; i < votesOOB.length; ++i) {
        double delta = ((float) (votesOOB[i] - votesSOOB[i])) / (float) nrows[i];
        imp += delta;
        v += delta * delta;
      }
      imp /= model.treeCount();
      varimp[var] = imp;
      varimpSD[var] = (float) Math.sqrt((v / model.treeCount() - imp * imp) / model.treeCount());
    }
    return new VarImp.VarImpMDA(varimp, varimpSD, model.treeCount());
  }

  public static float[] computeVarImpSD(float[][] vote_diffs) {
    float[] res = new float[vote_diffs.length];
    for (int var = 0; var < vote_diffs.length; ++var) {
      float mean_diffs = 0.f;
      float r = 0.f;
      // Mean over this variable's per-tree vote differences
      for (float d : vote_diffs[var]) mean_diffs += d / (float) vote_diffs[var].length;
      for (float d : vote_diffs[var]) r += (d - mean_diffs) * (d - mean_diffs);
      r *= 1.f / (float) vote_diffs[var].length;
      res[var] = (float) Math.sqrt(r);
    }
    return res;
  }

  @Override protected void setCrossValidationError(Job.ValidatedJob job, double cv_error, water.api.ConfusionMatrix cm, AUCData auc, HitRatio hr) {
    _have_cv_results = true;
    SpeeDRFModel drfm = ((SpeeDRF) job).makeModel(this, cv_error, cm.cm == null ? null : new ConfusionMatrix(cm.cm, this.nclasses()), this.varimp, auc);
    drfm._have_cv_results = true;
    DKV.put(this._key, drfm); // overwrite this model
  }
}