package hex.gapstat; import water.Job; import water.Key; import water.Model; import water.api.DocGen; import water.api.Request.API; import water.fvec.Frame; import water.util.D3Plot; public class GapStatisticModel extends Model implements Job.Progress { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. // @API(help = "Number of clusters to build in each iteration.") final int ks; @API(help = "The initial pooled within cluster sum of squares for each iteration.") final double[] wks; @API(help = "The log of the Wks.") final double[] wkbs; @API(help = "The standard error from the Monte Carlo simulated data for each iteration.") final double[] sk; // @API(help = "k_max.") final int k_max; // @API(help = "b_max.") final int b_max; // @API(help = "The current value of k_max: (2 <= k <= k_max).") int k; // @API(help = "The current value of B (1 <= b <= B.") int b; @API(help = "The gap statistics per value of k.") double[] gap_stats; @API(help = "Optimal number of clusters.") int k_best = 1; public GapStatisticModel(Key selfKey, Key dataKey, Frame fr, int ks, double[] wks, double[] log_wks, double[] sk, int k_max, int b_max, int k, int b) { super(selfKey, dataKey, fr, /* priorClassDistribution */ null); this.ks = ks; this.wks = wks; this.wkbs = log_wks; this.sk = sk; this.k_max = k_max; this.b_max = b_max; this.k = k; this.b = b; this.gap_stats = new double[this.wks.length]; } public double[] wks() { return wks; } public double[] wkbs() { return wkbs; } public double[] sk() {return sk; } public double[] gaps() {return gap_stats; } @Override public float progress() { return ((k-1)*(b_max+1) + b + 1)/ (float)(k_max*(b_max+1)); } @Override protected float[] score0(double[] data, float[] preds) { throw new UnsupportedOperationException(); } @Override public String toString(){ return String.format("Gap Statistic Model (key=%s , trained on %s):\n", _key, _dataKey); } public void generateHTML(String title, StringBuilder sb) { if(title != null && !title.isEmpty()) DocGen.HTML.title(sb, title); DocGen.HTML.paragraph(sb, "Model Key: " + _key); // sb.append("<div class='alert'>Actions: " + Predict.link(_key, "Predict on dataset") + ", " // + NaiveBayes.link(_dataKey, "Compute new model") + "</div>"); DocGen.HTML.section(sb, "Gap Statistic Output:"); //Log Pooled Variances... DocGen.HTML.section(sb, "Log of the Pooled Cluster Within Sum of Squares per value of k"); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); double[] log_wks = wks(); sb.append("<tr>"); for (int i = 0; i <log_wks.length; ++i) { if (log_wks[i] == 0) continue; sb.append("<th>").append(i+1).append("</th>"); } sb.append("</tr>"); sb.append("<tr>"); for (double log_wk : log_wks) { if (log_wk == 0) continue; sb.append("<td>").append(log_wk).append("</td>"); } sb.append("</tr>"); sb.append("</table></span>"); //Monte Carlo Bootstrap averages DocGen.HTML.section(sb, "Monte Carlo Bootstrap Replicate Averages of the Log of the Pooled Cluster Within SS per value of k"); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); double[] log_wkbs = wkbs(); sb.append("<tr>"); for (int i = 0; i <log_wkbs.length; ++i) { if (log_wkbs[i] == 0) continue; sb.append("<th>").append(i+1).append("</th>"); } sb.append("</tr>"); sb.append("<tr>"); for (double log_wkb : log_wkbs) { if (log_wkb == 0) continue; sb.append("<td>").append(log_wkb).append("</td>"); } sb.append("</tr>"); sb.append("</table></span>"); //standard errors DocGen.HTML.section(sb, "Standard Error for the Monte Carlo Bootstrap Replicate Averages of the Log of the Pooled Cluster Within SS per value of k"); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); double[] sks = sk(); sb.append("<tr>"); for (int i = 0; i <sks.length; ++i) { if (sks[i] == 0) continue; sb.append("<th>").append(i+1).append("</th>"); } sb.append("</tr>"); sb.append("<tr>"); for (double sk1 : sks) { if (sk1 == 0) continue; sb.append("<td>").append(sk1).append("</td>"); } sb.append("</tr>"); sb.append("</table></span>"); //Gap computation DocGen.HTML.section(sb, "Gap Statistic per value of k"); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); sb.append("<tr>"); for (int i = 0; i < log_wkbs.length; ++i) { if (log_wkbs[i] == 0) continue; sb.append("<th>").append(i+1).append("</th>"); } sb.append("</tr>"); double[] gaps = gaps(); sb.append("<tr>"); for (double gap : gaps) { if (gap == 0) continue; sb.append("<td>").append(gap).append("</td>"); } sb.append("</tr>"); sb.append("</table></span>"); //Compute optimal k: min k such that G_k >= G_(k+1) - s_(k+1) int kmin = compute_k_best(); if (log_wks[log_wks.length - 1] != 0) { DocGen.HTML.section(sb, "Best k:"); if (kmin <= 1) { sb.append("No optimal number of clusters found (best k = 1)."); } else { sb.append("k = ").append(kmin); } } else { DocGen.HTML.section(sb, "Best k so far:"); if (kmin <= 1) { sb.append("No k computed yet..."); } else { sb.append("k = ").append(kmin); } } float[] K = new float[ks]; float[] wks_y = new float[ks]; for(int i = 0; i < wks.length; ++i){ assert wks.length == ks; K[i] = i + 1; wks_y[i] = (float)wks[i]; } DocGen.HTML.section(sb, "Elbow Plot"); sb.append("<br />"); D3Plot plt = new D3Plot(K, wks_y, "k (Number of clusters)", " log( W_k ) ", "Elbow Plot", true, false); plt.generate(sb); float[] gs = new float[ks]; String[] names = new String[ks]; for (int i = 0; i < gs.length; ++i) { names[i] = "k = " + (i+1); gs[i] = (float)gap_stats[i]; } DocGen.HTML.section(sb, "Gap Statistics"); sb.append("<br />"); DocGen.HTML.graph(sb, "gapstats", "g_varimp", DocGen.HTML.toJSArray(new StringBuilder(), names, null, gap_stats.length), DocGen.HTML.toJSArray(new StringBuilder(), gs , null, gap_stats.length) ); DocGen.HTML.section(sb, "Gap Statistics Less Standard Errors"); sb.append("<br />"); float[] new_gs = new float[gs.length]; for (int i = 0; i < gs.length; ++i) { new_gs[i] = (float) (gs[i] - sks[i]); } DocGen.HTML.graph(sb, "g_minus_err", "g_varimp", DocGen.HTML.toJSArray(new StringBuilder(), names, null, gap_stats.length), DocGen.HTML.toJSArray(new StringBuilder(), new_gs , null, gap_stats.length) ); } int compute_k_best() { double[] gaps = gaps(); double[] log_wks = wks(); double[] sks = sk(); int kmin = -1; for (int i = 0; i < gaps.length - 1; ++i) { int cur_k = i + 1; if (gaps[cur_k] == 0) { kmin = 0; k_best = 1; //= kmin; break; } if (i == gaps.length - 1) { kmin = cur_k; k_best = kmin; break; } if (gaps[i] >= (gaps[i + 1] - sks[i + 1])) { kmin = cur_k; k_best = kmin; break; } } if (kmin <= 0) k_best = 1; if (log_wks[log_wks.length - 1] != 0) { if (kmin > 1) k_best = kmin; } else { if (kmin > 1) k_best = kmin; } if (k_best <= 0) k_best = (int)Double.NaN; if (k_best == 0) k_best = 1; return kmin; } }