package hex.glm; import hex.ConfusionMatrix; import hex.glm.GLMParams.Family; import water.Iced; import water.Key; import water.api.AUC; import water.api.DocGen; import water.api.Request.API; import water.util.ModelUtils; /** * Class for GLMValidation. * * @author tomasnykodym * */ public class GLMValidation extends Iced { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. @API(help="") double null_deviance; @API(help="") double residual_deviance; @API(help="") long nobs; @API(help="best decision threshold") float best_threshold; @API(help="") double auc = Double.NaN; @API(help="cross validation models") Key [] xval_models; @API(help="AIC") double aic;// internal aic used only for poisson family! @API(help="internal aic used only for poisson family!") private double _aic2;// internal aic used only for poisson family! @API(help="") final Key dataKey; @API(help="Decision thresholds used to generare confuion matrices, AUC and to find the best thresholds based on user criteria") public final float [] thresholds; @API(help="") ConfusionMatrix [] _cms; @API(help="") final GLMParams _glm; @API(help="") final private int _rank; public static class GLMXValidation extends GLMValidation { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. public GLMXValidation(GLMModel mainModel, GLMModel [] xvalModels, GLMValidation [] xvals, double lambda, long nobs, float [] thresholds) { super(mainModel._dataKey, mainModel.glm, mainModel.rank(lambda),thresholds); xval_models = new Key[xvalModels.length]; for(int i = 0; i < xval_models.length; ++i) xval_models[i] = xvalModels[i]._key; double t = 0; for(int i = 0; i < xvalModels.length; ++i){ add(xvals[i]); t += xvals[i].best_threshold; } computeAUC(); computeAIC(); best_threshold = (float)(t/xvalModels.length); this.nobs = nobs; } } public GLMValidation(Key dataKey, GLMParams glm, int rank){ this(dataKey, glm, rank,glm.family == Family.binomial?ModelUtils.DEFAULT_THRESHOLDS:null); } public GLMValidation(Key dataKey, GLMParams glm, int rank, float [] thresholds){ _rank = rank; _glm = glm; if(_glm.family == Family.binomial){ _cms = new ConfusionMatrix[thresholds.length]; for(int i = 0; i < _cms.length; ++i) _cms[i] = new ConfusionMatrix(2); } this.dataKey = dataKey; this.thresholds = thresholds; } public static Key makeKey(){return Key.make("__GLMValidation_" + Key.make());} public void add(double yreal, double ymodel){ if(_glm.family == Family.binomial) // classification -> update confusion matrix too for(int i = 0; i < thresholds.length; ++i) _cms[i].add((int)yreal, (ymodel >= thresholds[i])?1:0); residual_deviance += _glm.deviance(yreal, ymodel); ++nobs; if( _glm.family == Family.poisson ) { // aic for poisson long y = Math.round(yreal); double logfactorial = 0; for( long i = 2; i <= y; ++i ) logfactorial += Math.log(i); _aic2 += (yreal * Math.log(ymodel) - logfactorial - ymodel); } } public void add(GLMValidation v){ residual_deviance += v.residual_deviance; nobs += v.nobs; _aic2 += v._aic2; if(_cms == null)_cms = v._cms; else for(int i = 0; i < _cms.length; ++i)_cms[i].add(v._cms[i]); } public final double residualDeviance(){return residual_deviance;} public final double nullDeviance(){return null_deviance;} public final long resDOF(){return nobs - _rank -1;} public double auc(){return auc;} public double aic(){return aic;} protected void computeAIC(){ aic = 0; switch( _glm.family ) { case gaussian: aic = nobs * (Math.log(residual_deviance / nobs * 2 * Math.PI) + 1) + 2; break; case binomial: aic = residual_deviance; break; case poisson: aic = -2*_aic2; break; // aic is set during the validation task case gamma: case tweedie: aic = Double.NaN; break; default: assert false : "missing implementation for family " + _glm.family; } aic += 2*_rank; } protected void computeAUC(){ if(_glm.family == Family.binomial){ for(ConfusionMatrix cm:_cms)cm.reComputeErrors(); AUC auc = new AUC(_cms,thresholds,/*TODO: add CM domain*/null); this.auc = auc.data().AUC(); best_threshold = auc.data().threshold(); } } @Override public String toString(){ return " res_dev = " + residual_deviance + ", auc = " + auc(); } /** * Computes area under the ROC curve. The ROC curve is computed from the confusion matrices * (there is one for each computed threshold). Area under this curve is then computed as a sum * of areas of trapezoids formed by each neighboring points. * * @return estimate of the area under ROC curve of this classifier. */ double[] tprs; double[] fprs; private double trapeziod_area(double x1, double x2, double y1, double y2) { double base = Math.abs(x1 - x2); double havg = 0.5 * (y1 + y2); return base * havg; } }