package hex; import water.Iced; import water.Model; import water.api.DocGen; import water.api.Request.API; import water.util.UIUtils; import water.util.Utils; import java.util.Arrays; import java.util.Comparator; public class VarImp extends Iced { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. /** Variable importance measurement method. */ enum VarImpMethod { PERMUTATION_IMPORTANCE("Mean decrease accuracy"), RELATIVE_IMPORTANCE("Relative importance"); private final String title; VarImpMethod(String title) { this.title = title; } @Override public String toString() { return title; } } @API(help="Variable importance of individual variables.") public float[] varimp; @API(help="Names of variables.") protected String[] variables; @API(help="Variable importance measurement method.") public final VarImpMethod method; @API(help="Max. number of variables to show.") public final int max_var = 100; @API(help="Scaled measurements.") public final boolean scaled() { return false; } public VarImp(float[] varimp) { this(varimp, null, VarImpMethod.RELATIVE_IMPORTANCE); } public VarImp(float[] varimp, String[] variables) { this(varimp, variables, VarImpMethod.RELATIVE_IMPORTANCE); } protected VarImp(float[] varimp, String[] variables, VarImpMethod method) { this.varimp = varimp; this.variables = variables; this.method = method; } public String[] getVariables() { return variables; } public void setVariables(String[] variables) { this.variables = variables; } /** Generate variable importance HTML code. */ public final <T extends Model> StringBuilder toHTML(T model, StringBuilder sb) { DocGen.HTML.section(sb,"Variable importance of input variables: " + method); sb.append("<div class=\"alert\">"); sb.append(UIUtils.builderModelLink(model.getClass(), model._dataKey, model.responseName(), "Build a new model using selected variables", "redirectWithCols(this,'vi_chkb')")); sb.append("</div>"); DocGen.HTML.arrayHead(sb); // Create a sort order Integer[] sortOrder = getSortOrder(); // Generate variable labels and raw scores if (variables != null) DocGen.HTML.tableLine(sb, "Variable", variables, sortOrder, Math.min(max_var, variables.length), true, "vi_chkb"); if (varimp != null) DocGen.HTML.tableLine(sb, method.toString(), varimp, sortOrder, Math.min(max_var, variables.length)); // Print a specific information toHTMLAppendMoreTableLines(sb, sortOrder); DocGen.HTML.arrayTail(sb); // Generate nice graph ;-) toHTMLGraph(sb, sortOrder); // And return the result return sb; } protected StringBuilder toHTMLAppendMoreTableLines(StringBuilder sb, Integer[] sortOrder) { return sb; } protected StringBuilder toHTMLGraph(StringBuilder sb, Integer[] sortOrder) { return toHTMLGraph(sb, variables, varimp, sortOrder, max_var); } static final StringBuilder toHTMLGraph(StringBuilder sb, String[] names, float[] vals, Integer[] sortOrder, int max) { Integer[] so = vals.length > max ? sortOrder : null; // Generate a graph DocGen.HTML.graph(sb, "graphvarimp", "g_varimp", DocGen.HTML.toJSArray(new StringBuilder(), names, so, Math.min(max, vals.length)), DocGen.HTML.toJSArray(new StringBuilder(), vals , so, Math.min(max, vals.length)) ); sb.append("<button id=\"sortBars\" class=\"btn btn-primary\">Sort</button>\n"); return sb; } /** By default provides a sort order according to raw scores stored in <code>varimp</code>. */ protected Integer[] getSortOrder() { Integer[] sortOrder = new Integer[varimp.length]; for(int i=0; i<sortOrder.length; i++) sortOrder[i] = i; Arrays.sort(sortOrder, new Comparator<Integer>() { @Override public int compare(Integer o1, Integer o2) { float f = varimp[o1]-varimp[o2]; return f<0 ? 1 : (f>0 ? -1 : 0); } }); return sortOrder; } /** Variable importance measured as relative influence. * It provides raw values, scaled values, and summary. * Motivate by R's GBM package. */ public static class VarImpRI extends VarImp { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. public VarImpRI(float[] varimp) { super(varimp); } @API(help = "Scaled values of raw scores with respect to maximal value (GBM call - relative.influnce(model, scale=T)).") public float[] scaled_values() { float[] scaled = new float[varimp.length]; int maxVar = 0; for (int i=0; i<varimp.length; i++) if (varimp[i] > varimp[maxVar]) maxVar = i; float maxVal = varimp[maxVar]; for (int var=0; var<varimp.length; var++) scaled[var] = varimp[var] / maxVal; return scaled; } @API(help = "Summary of values in percent (the same as produced by summary.gbm).") public float[] summary() { float[] summary = new float[varimp.length]; float sum = Utils.sum(varimp); for (int var=0; var<varimp.length; var++) summary[var] = 100*varimp[var] / sum; return summary; } @Override protected StringBuilder toHTMLAppendMoreTableLines(StringBuilder sb, Integer[] sortOrder ) { StringBuilder ssb = super.toHTMLAppendMoreTableLines(sb, sortOrder); DocGen.HTML.tableLine(sb, "Scaled values", scaled_values(), sortOrder, Math.min(max_var, varimp.length)); DocGen.HTML.tableLine(sb, "Influence in %", summary(), sortOrder, Math.min(max_var, varimp.length)); return ssb; } @Override protected StringBuilder toHTMLGraph(StringBuilder sb, Integer[] sortOrder) { return toHTMLGraph(sb, variables, scaled_values(), sortOrder, max_var ); } } /** Variable importance measured as mean decrease in accuracy. * It provides raw variable importance measures, SD and z-scores. */ public static class VarImpMDA extends VarImp { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. @API(help="Variable importance SD for individual variables.") public final float[] varimpSD; /** Number of trees participating for producing variable importance measurements */ private final int ntrees; public VarImpMDA(float[] varimp, float[] varimpSD, int ntrees) { super(varimp,null,VarImpMethod.PERMUTATION_IMPORTANCE); this.varimpSD = varimpSD; this.ntrees = ntrees; } @API(help = "Z-score for individual variables") public float[] z_score() { float[] zscores = new float[varimp.length]; double rnt = Math.sqrt(ntrees); for(int v = 0; v < varimp.length ; v++) zscores[v] = (float) (varimp[v] / (varimpSD[v] / rnt)); return zscores; } @Override protected StringBuilder toHTMLAppendMoreTableLines(StringBuilder sb, Integer[] sortOrder ) { StringBuilder ssb = super.toHTMLAppendMoreTableLines(sb, sortOrder); if (varimpSD!=null) { DocGen.HTML.tableLine(sb, "SD", varimpSD, sortOrder, Math.min(max_var, varimp.length)); float[] zscores = z_score(); DocGen.HTML.tableLine(sb, "Z-scores", zscores, sortOrder, Math.min(max_var, varimp.length)); } return ssb; } } }