package hex.nb;
import hex.FrameTask.DataInfo;
import water.*;
import water.api.DocGen;
import water.fvec.*;
import water.util.RString;
import water.util.Utils;
/**
* Naive Bayes
* This is an algorithm for computing the conditional a posteriori probabilities of a categorical
* response from predictors that are assumed to be independent given the response, using Bayes' rule.
* <a href = "http://en.wikipedia.org/wiki/Naive_Bayes_classifier">Naive Bayes on Wikipedia</a>
* <a href = "http://cs229.stanford.edu/notes/cs229-notes2.pdf">Lecture Notes by Andrew Ng</a>
* @author anqi_fu
*
*/
public class NaiveBayes extends Job.ModelJobWithoutClassificationField {
  static final int API_WEAVER = 1;
  static public DocGen.FieldDoc[] DOC_FIELDS;
  static final String DOC_GET = "naive bayes";

  // Smoothing constant added to every count before normalizing (0 disables smoothing).
  @API(help = "Laplace smoothing parameter", filter = Default.class, lmin = 0, lmax = 100000, json = true)
  public int laplace = 0;

  // Lower bound on the standard deviation; validated in init() and passed through to the model.
  @API(help = "Min. standard deviation to use for observations with not enough data", filter = Default.class, dmin = 1e-10, json = true)
  public double min_std_dev = 1e-3;

  // Forwarded to DataInfo.prepareFrame when the training frame is assembled.
  @API(help = "Drop columns with more than 20% missing values", filter = Default.class)
  public boolean drop_na_cols = true;

  /**
   * Runs the job: prepares the training frame, accumulates the counts and sums in one
   * {@link NBTask} pass over the data, and turns them into a model via
   * {@link #buildModel(DataInfo, NBTask, double, double)}.
   */
  @Override protected void execImpl() {
    long before = System.currentTimeMillis();
    Frame fr = DataInfo.prepareFrame(source, response, ignored_cols, false, true /*drop const*/, drop_na_cols);
    DataInfo dinfo = new DataInfo(fr, 1, false, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE);
    NBTask tsk = new NBTask(this, dinfo).doAll(dinfo._adaptedFrame);
    NBModel myModel = buildModel(dinfo, tsk, laplace, min_std_dev);
    myModel.start_training(before);
    myModel.stop_training();
    myModel.delete_and_lock(self());
    myModel.unlock(self());
  }

  /**
   * Validates the user-supplied parameters.
   *
   * @throws IllegalArgumentException if the response is not categorical, if {@code laplace} is
   *         negative, or if {@code min_std_dev} is below 1e-10 (or is NaN)
   */
  @Override protected void init() {
    super.init();
    if(!response.isEnum())
      throw new IllegalArgumentException("Response must be a categorical column");
    if (laplace < 0) throw new IllegalArgumentException("Laplace smoothing must be an integer >= 0.");
    // Accept exactly 1e-10 (the documented minimum) and reject NaN, which fails every comparison.
    if (!(min_std_dev >= 1e-10)) throw new IllegalArgumentException("Min. standard deviation must be at least 1e-10.");
  }

  /** Redirects the caller to the progress page for this job. */
  @Override protected Response redirect() {
    return NBProgressPage.redirect(this, self(), dest());
  }

  /**
   * Builds an HTML anchor pointing at the Naive Bayes query page for a source frame.
   *
   * @param src_key key of the source frame, substituted as the {@code source} request parameter
   * @param content text displayed inside the anchor
   * @return the rendered anchor markup
   */
  public static String link(Key src_key, String content) {
    RString rs = new RString("<a href='/2/NaiveBayes.query?%key_param=%$key'>%content</a>");
    rs.replace("key_param", "source");
    rs.replace("key", src_key.toString());
    rs.replace("content", content);
    return rs.toString();
  }

  /**
   * Converts the raw counts and sums gathered by an {@link NBTask} into model parameters.
   * The task's own accumulators are left untouched; all arithmetic is done on copies.
   *
   * @param dinfo       description of the adapted training frame (categoricals first, then numerics)
   * @param tsk         completed task holding response counts and joint counts/sums
   * @param laplace     smoothing constant added to every count
   * @param min_std_dev minimum standard deviation, passed through to the model unchanged
   * @return the model built from the smoothed priors and conditional parameters
   */
  public NBModel buildModel(DataInfo dinfo, NBTask tsk, double laplace, double min_std_dev) {
    logStart();
    double[] pprior = tsk._rescnt.clone();
    // clone() on a multi-dimensional array copies only the outermost level, so a deep copy is
    // required to keep the in-place updates below from overwriting the task's raw counts.
    double[][][] pcond = deepCopy(tsk._jntcnt);
    String[][] domains = dinfo._adaptedFrame.domains();

    // A-priori probability of response y
    for(int i = 0; i < pprior.length; i++)
      pprior[i] = (pprior[i] + laplace)/(tsk._nobs + tsk._nres*laplace);
    // pprior[i] = pprior[i]/tsk._nobs; // Note: R doesn't apply laplace smoothing to priors, even though this is textbook definition

    // Probability of categorical predictor x_j conditional on response y
    for(int col = 0; col < dinfo._cats; col++) {
      assert pcond[col].length == tsk._nres;
      for(int i = 0; i < pcond[col].length; i++) {
        for(int j = 0; j < pcond[col][i].length; j++)
          pcond[col][i][j] = (pcond[col][i][j] + laplace)/(tsk._rescnt[i] + domains[col].length*laplace);
      }
    }

    // Mean and standard deviation of numeric predictor x_j for every level of response y
    for(int col = 0; col < dinfo._nums; col++) {
      int cidx = dinfo._cats + col;
      for(int i = 0; i < pcond[cidx].length; i++) {
        double num = tsk._rescnt[i];
        double pmean = pcond[cidx][i][0]/num;
        pcond[cidx][i][0] = pmean;
        // Sample (n-1) variance computed from the running sum and sum of squares.
        double pvar = pcond[cidx][i][1]/(num - 1) - pmean*pmean*num/(num - 1);
        // Rounding can push a near-zero variance slightly negative; clamp so sqrt does not
        // return NaN. A NaN variance (fewer than two rows for this level) is left as NaN.
        pcond[cidx][i][1] = Math.sqrt(Math.max(pvar, 0.0));
      }
    }

    Key dataKey = input("source") == null ? null : Key.make(input("source"));
    return new NBModel(destination_key, dataKey, dinfo, tsk, pprior, pcond, laplace, min_std_dev);
  }

  /** Returns a copy of {@code src} that shares no sub-array with it. */
  private static double[][][] deepCopy(double[][][] src) {
    double[][][] dst = new double[src.length][][];
    for(int i = 0; i < src.length; i++) {
      dst[i] = new double[src[i].length][];
      for(int j = 0; j < src[i].length; j++)
        dst[i][j] = src[i][j].clone();
    }
    return dst;
  }

  // Note: NA handling differs from R for efficiency purposes
  // R's method: For each predictor x_j, skip counting that row for p(x_j|y) calculation if x_j = NA. If response y = NA, skip counting row entirely in all calculations
  // H2O's method: Just skip all rows where any x_j = NA or y = NA. Should be more memory-efficient, but results incomparable with R.
  /**
   * Map/reduce pass that counts response levels and, per predictor, accumulates either the
   * joint counts with the response (categorical) or the sum and sum of squares per response
   * level (numeric). The response is taken to be the last column of the adapted frame.
   */
  public static class NBTask extends MRTask2<NBTask> {
    final Job _job;
    final protected DataInfo _dinfo;
    final int _nres;              // Number of levels for the response y
    public int _nobs;             // Number of rows counted in calculation
    public double[] _rescnt;      // Count of each level in the response
    public double[][][] _jntcnt;  // For each categorical predictor, joint count of response and predictor levels
                                  // For each numeric predictor, sum of entries for every response level

    /**
     * Sizes the accumulators from the adapted frame's domains.
     *
     * @param job   owning job
     * @param dinfo description of the adapted frame; its last column must be the response
     */
    public NBTask(Job job, DataInfo dinfo) {
      _job = job;
      _dinfo = dinfo;
      _nobs = 0;
      String[][] domains = dinfo._adaptedFrame.domains();
      int ncol = dinfo._adaptedFrame.numCols();
      assert ncol-1 == dinfo._nums + dinfo._cats; // ncol-1 because we drop response col
      _nres = domains[ncol-1].length;
      // NOTE(review): the accumulators are allocated here rather than in map(); confirm that the
      // framework gives every map() invocation its own copy of these arrays.
      _rescnt = new double[_nres];
      _jntcnt = new double[ncol-1][][];
      for(int i = 0; i < _jntcnt.length; i++) {
        // Numeric columns (null domain) need two slots: [0] = sum, [1] = sum of squares.
        int ncnt = domains[i] == null ? 2 : domains[i].length;
        _jntcnt[i] = new double[_nres][ncnt];
      }
    }

    /** Accumulates counts and sums for every row of the chunk that contains no NA. */
    @Override public void map(Chunk[] chks) {
      int res_idx = chks.length - 1;
      Chunk res = chks[res_idx];
      OUTER:
      for(int row = 0; row < chks[0]._len; row++) {
        // Skip row if any entries in it are NA
        for(int col = 0; col < chks.length; col++) {
          if(chks[col].isNA0(row)) continue OUTER;
        }
        // Record joint counts of categorical predictors and response
        int rlevel = (int)res.at0(row);
        for(int col = 0; col < _dinfo._cats; col++) {
          int plevel = (int)chks[col].at0(row);
          _jntcnt[col][rlevel][plevel]++;
        }
        // Record sum for each pair of numerical predictors and response
        for(int col = 0; col < _dinfo._nums; col++) {
          int cidx = _dinfo._cats + col;
          double x = chks[cidx].at0(row);
          _jntcnt[cidx][rlevel][0] += x;
          _jntcnt[cidx][rlevel][1] += x*x;
        }
        _rescnt[rlevel]++;
        _nobs++;
      }
    }

    /** Merges the row count, response counts and joint counts/sums of another task into this one. */
    @Override public void reduce(NBTask nt) {
      _nobs += nt._nobs;
      Utils.add(_rescnt, nt._rescnt);
      for(int col = 0; col < _jntcnt.length; col++)
        _jntcnt[col] = Utils.add(_jntcnt[col], nt._jntcnt[col]);
    }
  }
}