package hex.pca; import hex.FrameTask.DataInfo; import hex.gram.Gram.GramTask; import water.Key; import water.MemoryManager; import water.Model; import water.Request2; import water.api.DocGen; import water.api.Request.API; import water.api.RequestBuilders.ElementBuilder; public class PCAModel extends Model { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. @API(help = "Column names expanded to accommodate categoricals") final String[] namesExp; @API(help = "Standard deviation of each principal component") final double[] sdev; @API(help = "Proportion of variance explained by each principal component") final double[] propVar; @API(help = "Cumulative proportion of variance explained by each principal component") final double[] cumVar; @API(help = "Principal components (eigenvector) matrix") final double[][] eigVec; @API(help = "If standardized, mean of each numeric data column") final double[] normSub; @API(help = "If standardized, one over standard deviation of each numeric data column") final double[] normMul; @API(help = "Offsets of categorical columns into the sdev vector. The last value is the offset of the first numerical column.") final int[] catOffsets; @API(help = "Rank of eigenvector matrix") final int rank; @API(help = "Number of principal components to display") int num_pc; @API(help = "Model parameters") PCA parameters; public PCAModel(PCA params, Key selfKey, Key dataKey, DataInfo dinfo, GramTask gramt, double[] sdev, double[] propVar, double[] cumVar, double[][] eigVec, int rank, int num_pc) { super(selfKey, dataKey, dinfo._adaptedFrame, /* priorClassDistribution */ null); this.sdev = sdev; this.propVar = propVar; this.cumVar = cumVar; this.eigVec = eigVec; this.parameters = params; this.catOffsets = dinfo._catOffsets; this.namesExp = namesExp(); this.rank = rank; this.num_pc = num_pc; // TODO: Need to ensure this maps correctly to scored data cols this.normSub = gramt.normSub(); this.normMul = gramt.normMul(); } @Override public final PCA get_params() { return parameters; } @Override public final Request2 job() { return get_params(); } @Override public int nfeatures() { return _names.length; } @Override public boolean isSupervised() { return false; } @Override public String responseName() { throw new IllegalArgumentException("PCA doesn't have a response."); } public double[] sdev() { return sdev; } public double[][] eigVec() { return eigVec; } @Override protected float[] score0(double[] data, float[] preds) { throw new RuntimeException("TODO Auto-generated method stub"); } @Override public String toString(){ StringBuilder sb = new StringBuilder("PCA Model (key=" + _key + " , trained on " + _dataKey + "):\n"); return sb.toString(); } public String[] namesExp(){ final int n = _names.length; int[] nums = MemoryManager.malloc4(n); int[] cats = MemoryManager.malloc4(n); // Store indices of numeric and categorical cols int nnums = 0, ncats = 0; for(int i = 0; i < n; ++i){ if(_domains[i] != null) cats[ncats++] = i; else nums[nnums++] = i; } // Sort the categoricals in decreasing order according to size for(int i = 0; i < ncats; ++i) for(int j = i+1; j < ncats; ++j) if(_domains[cats[i]].length < _domains[cats[j]].length) { int x = cats[i]; cats[i] = cats[j]; cats[j] = x; } // Construct expanded col names, with categoricals first followed by numerics int k = 0; String[] names = new String[sdev.length]; for(int i = 0; i < ncats; ++i){ for(int j = 1; j < _domains[cats[i]].length; ++j) names[k++] = _names[cats[i]] + "." + _domains[cats[i]][j]; } for(int i = 0; i < nnums; ++i) { names[k++] = _names[nums[i]]; } return names; } public void generateHTML(String title, StringBuilder sb) { if(title != null && !title.isEmpty()) DocGen.HTML.title(sb, title); DocGen.HTML.paragraph(sb, "Model Key: " + _key); job().toHTML(sb); sb.append("<script type=\"text/javascript\" src='/h2o/js/d3.v3.min.js'></script>"); sb.append("<div class='alert'>Actions: " + PCAScore.link(_key, "Score on dataset") + (_dataKey != null ? (", " + PCA.link(_dataKey, "Compute new model")):"") + "</div>"); screevarString(sb); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); sb.append("<tr>"); sb.append("<th>Feature</th>"); for(int i = 0; i < num_pc; i++) sb.append("<th>").append("PC" + i).append("</th>"); sb.append("</tr>"); // Row of standard deviation values sb.append("<tr class='warning'>"); // sb.append("<td>").append("σ").append("</td>"); sb.append("<td>").append("Std Dev").append("</td>"); for(int c = 0; c < num_pc; c++) sb.append("<td>").append(ElementBuilder.format(sdev[c])).append("</td>"); sb.append("</tr>"); // Row with proportion of variance sb.append("<tr class='warning'>"); sb.append("<td>").append("Prop Var").append("</td>"); for(int c = 0; c < num_pc; c++) sb.append("<td>").append(ElementBuilder.format(propVar[c])).append("</td>"); sb.append("</tr>"); // Row with cumulative proportion of variance sb.append("<tr class='warning'>"); sb.append("<td>").append("Cum Prop Var").append("</td>"); for(int c = 0; c < num_pc; c++) sb.append("<td>").append(ElementBuilder.format(cumVar[c])).append("</td>"); sb.append("</tr>"); // Each row is component of eigenvector for(int r = 0; r < eigVec.length; r++) { sb.append("<tr>"); sb.append("<th>").append(namesExp[r]).append("</th>"); for( int c = 0; c < num_pc; c++ ) { double e = eigVec[r][c]; sb.append("<td>").append(ElementBuilder.format(e)).append("</td>"); } sb.append("</tr>"); } sb.append("</table></span>"); } public void screevarString(StringBuilder sb) { sb.append("<div class=\"pull-left\"><a href=\"#\" onclick=\'$(\"#scree_var\").toggleClass(\"hide\");\' class=\'btn btn-inverse btn-mini\'>Scree & Variance Plots</a></div>"); sb.append("<div class=\"hide\" id=\"scree_var\">"); sb.append("<style type=\"text/css\">"); sb.append(".axis path," + ".axis line {\n" + "fill: none;\n" + "stroke: black;\n" + "shape-rendering: crispEdges;\n" + "}\n" + ".axis text {\n" + "font-family: sans-serif;\n" + "font-size: 11px;\n" + "}\n"); sb.append("</style>"); sb.append("<div id=\"scree\" style=\"display:inline;\">"); sb.append("<script type=\"text/javascript\">"); sb.append("//Width and height\n"); sb.append("var w = 500;\n"+ "var h = 300;\n"+ "var padding = 40;\n" ); sb.append("var dataset = ["); for(int c = 0; c < num_pc; c++) { if (c == 0) { sb.append("["+String.valueOf(c+1)+",").append(ElementBuilder.format(sdev[c]*sdev[c])).append("]"); } sb.append(", ["+String.valueOf(c+1)+",").append(ElementBuilder.format(sdev[c]*sdev[c])).append("]"); } sb.append("];"); sb.append( "//Create scale functions\n"+ "var xScale = d3.scale.linear()\n"+ ".domain([0, d3.max(dataset, function(d) { return d[0]; })])\n"+ ".range([padding, w - padding * 2]);\n"+ "var yScale = d3.scale.linear()"+ ".domain([0, d3.max(dataset, function(d) { return d[1]; })])\n"+ ".range([h - padding, padding]);\n"+ "var rScale = d3.scale.linear()"+ ".domain([0, d3.max(dataset, function(d) { return d[1]; })])\n"+ ".range([2, 5]);\n"+ "//Define X axis\n"+ "var xAxis = d3.svg.axis()\n"+ ".scale(xScale)\n"+ ".orient(\"bottom\")\n"+ ".ticks(5);\n"+ "//Define Y axis\n"+ "var yAxis = d3.svg.axis()\n"+ ".scale(yScale)\n"+ ".orient(\"left\")\n"+ ".ticks(5);\n"+ "//Create SVG element\n"+ "var svg = d3.select(\"#scree\")\n"+ ".append(\"svg\")\n"+ ".attr(\"width\", w)\n"+ ".attr(\"height\", h);\n"+ "//Create circles\n"+ "svg.selectAll(\"circle\")\n"+ ".data(dataset)\n"+ ".enter()\n"+ ".append(\"circle\")\n"+ ".attr(\"cx\", function(d) {\n"+ "return xScale(d[0]);\n"+ "})\n"+ ".attr(\"cy\", function(d) {\n"+ "return yScale(d[1]);\n"+ "})\n"+ ".attr(\"r\", function(d) {\n"+ "return 2;\n"+//rScale(d[1]);\n"+ "});\n"+ "/*"+ "//Create labels\n"+ "svg.selectAll(\"text\")"+ ".data(dataset)"+ ".enter()"+ ".append(\"text\")"+ ".text(function(d) {"+ "return d[0] + \",\" + d[1];"+ "})"+ ".attr(\"x\", function(d) {"+ "return xScale(d[0]);"+ "})"+ ".attr(\"y\", function(d) {"+ "return yScale(d[1]);"+ "})"+ ".attr(\"font-family\", \"sans-serif\")"+ ".attr(\"font-size\", \"11px\")"+ ".attr(\"fill\", \"red\");"+ "*/\n"+ "//Create X axis\n"+ "svg.append(\"g\")"+ ".attr(\"class\", \"axis\")"+ ".attr(\"transform\", \"translate(0,\" + (h - padding) + \")\")"+ ".call(xAxis);\n"+ "//X axis label\n"+ "d3.select('#scree svg')"+ ".append(\"text\")"+ ".attr(\"x\",w/2)"+ ".attr(\"y\",h - 5)"+ ".attr(\"text-anchor\", \"middle\")"+ ".text(\"Principal Component\");\n"+ "//Create Y axis\n"+ "svg.append(\"g\")"+ ".attr(\"class\", \"axis\")"+ ".attr(\"transform\", \"translate(\" + padding + \",0)\")"+ ".call(yAxis);\n"+ "//Y axis label\n"+ "d3.select('#scree svg')"+ ".append(\"text\")"+ ".attr(\"x\",150)"+ ".attr(\"y\",-5)"+ ".attr(\"transform\", \"rotate(90)\")"+ //".attr(\"transform\", \"translate(0,\" + (h - padding) + \")\")"+ ".attr(\"text-anchor\", \"middle\")"+ ".text(\"Eigenvalue\");\n"+ "//Title\n"+ "d3.select('#scree svg')"+ ".append(\"text\")"+ ".attr(\"x\",w/2)"+ ".attr(\"y\",padding - 20)"+ ".attr(\"text-anchor\", \"middle\")"+ ".text(\"Scree Plot\");\n"); sb.append("</script>"); sb.append("</div>"); /////////////////////////////////// sb.append("<div id=\"var\" style=\"display:inline;\">"); sb.append("<script type=\"text/javascript\">"); sb.append("//Width and height\n"); sb.append("var w = 500;\n"+ "var h = 300;\n"+ "var padding = 50;\n" ); sb.append("var dataset = ["); for(int c = 0; c < num_pc; c++) { if (c == 0) { sb.append("["+String.valueOf(c+1)+",").append(ElementBuilder.format(cumVar[c])).append("]"); } sb.append(", ["+String.valueOf(c+1)+",").append(ElementBuilder.format(cumVar[c])).append("]"); } sb.append("];"); sb.append( "//Create scale functions\n"+ "var xScale = d3.scale.linear()\n"+ ".domain([0, d3.max(dataset, function(d) { return d[0]; })])\n"+ ".range([padding, w - padding * 2]);\n"+ "var yScale = d3.scale.linear()"+ ".domain([0, d3.max(dataset, function(d) { return d[1]; })])\n"+ ".range([h - padding, padding]);\n"+ "var rScale = d3.scale.linear()"+ ".domain([0, d3.max(dataset, function(d) { return d[1]; })])\n"+ ".range([2, 5]);\n"+ "//Define X axis\n"+ "var xAxis = d3.svg.axis()\n"+ ".scale(xScale)\n"+ ".orient(\"bottom\")\n"+ ".ticks(5);\n"+ "//Define Y axis\n"+ "var yAxis = d3.svg.axis()\n"+ ".scale(yScale)\n"+ ".orient(\"left\")\n"+ ".ticks(5);\n"+ "//Create SVG element\n"+ "var svg = d3.select(\"#var\")\n"+ ".append(\"svg\")\n"+ ".attr(\"width\", w)\n"+ ".attr(\"height\", h);\n"+ "//Create circles\n"+ "svg.selectAll(\"circle\")\n"+ ".data(dataset)\n"+ ".enter()\n"+ ".append(\"circle\")\n"+ ".attr(\"cx\", function(d) {\n"+ "return xScale(d[0]);\n"+ "})\n"+ ".attr(\"cy\", function(d) {\n"+ "return yScale(d[1]);\n"+ "})\n"+ ".attr(\"r\", function(d) {\n"+ "return 2;\n"+//rScale(d[1]);\n"+ "});\n"+ "/*"+ "//Create labels\n"+ "svg.selectAll(\"text\")"+ ".data(dataset)"+ ".enter()"+ ".append(\"text\")"+ ".text(function(d) {"+ "return d[0] + \",\" + d[1];"+ "})"+ ".attr(\"x\", function(d) {"+ "return xScale(d[0]);"+ "})"+ ".attr(\"y\", function(d) {"+ "return yScale(d[1]);"+ "})"+ ".attr(\"font-family\", \"sans-serif\")"+ ".attr(\"font-size\", \"11px\")"+ ".attr(\"fill\", \"red\");"+ "*/\n"+ "//Create X axis\n"+ "svg.append(\"g\")"+ ".attr(\"class\", \"axis\")"+ ".attr(\"transform\", \"translate(0,\" + (h - padding) + \")\")"+ ".call(xAxis);\n"+ "//X axis label\n"+ "d3.select('#var svg')"+ ".append(\"text\")"+ ".attr(\"x\",w/2)"+ ".attr(\"y\",h - 5)"+ ".attr(\"text-anchor\", \"middle\")"+ ".text(\"Principal Component\");\n"+ "//Create Y axis\n"+ "svg.append(\"g\")"+ ".attr(\"class\", \"axis\")"+ ".attr(\"transform\", \"translate(\" + padding + \",0)\")"+ ".call(yAxis);\n"+ "//Y axis label\n"+ "d3.select('#var svg')"+ ".append(\"text\")"+ ".attr(\"x\",150)"+ ".attr(\"y\",-5)"+ ".attr(\"transform\", \"rotate(90)\")"+ //".attr(\"transform\", \"translate(0,\" + (h - padding) + \")\")"+ ".attr(\"text-anchor\", \"middle\")"+ ".text(\"Cumulative Proportion of Variance\");\n"+ "//Title\n"+ "d3.select('#var svg')"+ ".append(\"text\")"+ ".attr(\"x\",w/2)"+ ".attr(\"y\",padding-20)"+ ".attr(\"text-anchor\", \"middle\")"+ ".text(\"Cumulative Variance Plot\");\n"); sb.append("</script>"); sb.append("</div>"); sb.append("</div>"); sb.append("<br />"); } }