package hex.nb; import hex.FrameTask.DataInfo; import hex.nb.NaiveBayes.NBTask; import org.apache.commons.math3.distribution.NormalDistribution; import water.Key; import water.Model; import water.Request2; import water.api.DocGen; import water.api.Predict; import water.api.Request.API; import water.api.RequestBuilders.ElementBuilder; /** * FIXME comment please */ public class NBModel extends Model { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. @API(help = "Class counts of the dependent variable") final double[] rescnt; @API(help = "Class distribution of the dependent variable") final double[] pprior; @API(help = "For every predictor variable, a table giving, for each attribute level, the conditional probabilities given the target class") final double[][][] pcond; @API(help = "Number of categorical predictor variables") final int ncats; @API(help = "Number of numeric predictor variables") final int nnums; @API(help = "Laplace smoothing parameter") final double laplace; @API(help = "Min. standard deviation to use for observations with not enough data") final double min_std_dev; @API(help = "Model parameters", json = true) private Request2 job; @Override public final NaiveBayes get_params() { return (NaiveBayes)job; } @Override public final Request2 job() { return job; } public NBModel(Key selfKey, Key dataKey, DataInfo dinfo, NBTask tsk, double[] pprior, double[][][] pcond, double laplace, double min_std_dev) { super(selfKey, dataKey, dinfo._adaptedFrame, /* priorClassDistribution */ null); this.rescnt = tsk._rescnt; this.job= tsk._job; this.pprior = pprior; this.pcond = pcond; this.ncats = dinfo._cats; this.nnums = dinfo._nums; this.laplace = laplace; this.min_std_dev = min_std_dev; } public double[] pprior() { return pprior; } public double[][][] pcond() { return pcond; } // Note: For small probabilities, product may end up zero due to underflow error. Can circumvent by taking logs. @Override protected float[] score0(double[] data, float[] preds) { double denom = 0; assert preds.length == (pprior.length + 1); // Note: First column of preds is predicted response class // Compute joint probability of predictors for every response class for(int rlevel = 0; rlevel < pprior.length; rlevel++) { double num = 1; for(int col = 0; col < ncats; col++) { if(Double.isNaN(data[col])) continue; // Skip predictor in joint x_1,...,x_m if NA int plevel = (int)data[col]; num *= pcond[col][rlevel][plevel]; // p(x|y) = \Pi_{j = 1}^m p(x_j|y) } // For numeric predictors, assume Gaussian distribution with sample mean and variance from model for(int col = ncats; col < data.length; col++) { if(Double.isNaN(data[col])) continue; // Two ways to get non-zero std deviation HEX-1852 // double stddev = pcond[col][rlevel][1] > 0 ? pcond[col][rlevel][1] : min_std_dev; //only use the placeholder for critically low data double stddev = Math.max(pcond[col][rlevel][1], min_std_dev); // more stable for almost constant data double mean = pcond[col][rlevel][0]; double x = data[col]; num *= Math.exp(-((x-mean)*(x-mean)/(2.*stddev*stddev)))/stddev/Math.sqrt(2.*Math.PI); // faster // num *= new NormalDistribution(mean, stddev).density(data[col]); //slower } num *= pprior[rlevel]; // p(x,y) = p(x|y)*p(y) denom += num; // p(x) = \Sum_{levels of y} p(x,y) preds[rlevel+1] = (float)num; } // Select class with highest conditional probability float max = -1; for(int i = 1; i < preds.length; i++) { preds[i] /= denom; // p(y|x) = p(x,y)/p(x) if(preds[i] > max) { max = preds[i]; preds[0] = i-1; } } return preds; } @Override public String toString(){ StringBuilder sb = new StringBuilder("Naive Bayes Model (key=" + _key + " , trained on " + _dataKey + "):\n"); return sb.toString(); } public void generateHTML(String title, StringBuilder sb) { if(title != null && !title.isEmpty()) DocGen.HTML.title(sb, title); DocGen.HTML.paragraph(sb, "Model Key: " + _key); sb.append("<div class='alert'>Actions: " + Predict.link(_key, "Predict on dataset") + ", " + NaiveBayes.link(_dataKey, "Compute new model") + "</div>"); DocGen.HTML.section(sb, "A-Priori Probabilities"); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); // Domain of the response variable String[] resdom = _domains[_domains.length-1]; sb.append("<tr>"); for(int i = 0; i < resdom.length; i++) sb.append("<th>").append(resdom[i]).append("</th>"); sb.append("</tr>"); // Display table of a-priori response probabilities sb.append("<tr>"); for(int i = 0; i < pprior.length; i++) sb.append("<td>").append(ElementBuilder.format(pprior[i])).append("</td>"); sb.append("</tr>"); sb.append("</table></span>"); DocGen.HTML.section(sb, "Conditional Probabilities"); // Display table of conditional probabilities for categorical predictors for(int col = 0; col < ncats; col++) { DocGen.HTML.paragraph(sb, "Column: " + _names[col]); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); // Domain of the predictor variable sb.append("<tr>"); sb.append("<th>").append("Response/Predictor").append("</th>"); for(int i = 0; i < _domains[col].length; i++) sb.append("<th>").append(_domains[col][i]).append("</th>"); sb.append("</tr>"); // For each predictor, display table of conditional probabilities for(int r = 0; r < pcond[col].length; r++) { sb.append("<tr>"); sb.append("<th>").append(resdom[r]).append("</th>"); for(int c = 0; c < pcond[col][r].length; c++) { double e = pcond[col][r][c]; sb.append("<td>").append(ElementBuilder.format(e)).append("</td>"); } sb.append("</tr>"); } sb.append("</table></span>"); } // Display table of statistics for numeric predictors for(int col = ncats; col < ncats + nnums; col++) { DocGen.HTML.paragraph(sb, "Column: " + _names[col]); sb.append("<span style='display: inline-block;'>"); sb.append("<table class='table table-striped table-bordered'>"); // Labels for the predictor variable columns sb.append("<tr>"); sb.append("<th>").append("Response/Predictor").append("</th>"); sb.append("<th>").append("Mean").append("</th>"); sb.append("<th>").append("Standard Deviation").append("</th>"); sb.append("</tr>"); // For each predictor, display mean and standard deviation within every response level for(int r = 0; r < pcond[col].length; r++) { sb.append("<tr>"); sb.append("<th>").append(resdom[r]).append("</th>"); double pmean = pcond[col][r][0]; double psdev = pcond[col][r][1]; sb.append("<td>").append(ElementBuilder.format(pmean)).append("</td>"); sb.append("<td>").append(ElementBuilder.format(psdev)).append("</td>"); sb.append("</tr>"); } sb.append("</table></span>"); } } }