package hex; import Jama.Matrix; import java.util.Arrays; import jsr166y.ForkJoinTask; import jsr166y.RecursiveAction; import hex.FrameTask.DataInfo; import water.DKV; import water.Futures; import water.Job; import water.Key; import water.MemoryManager; import water.Model; import water.Request2; import water.api.CoxPHProgressPage; import water.api.DocGen; import water.fvec.Frame; import water.fvec.Vec; import water.fvec.Vec.CollectDomain; import water.util.Utils; public class CoxPH extends Job { @API(help="Data Frame", required=true, filter=Default.class, json=true) public Frame source; @API(help="Start Time Column", required=false, filter=CoxPHVecSelect.class, json=true) public Vec start_column = null; @API(help="Stop Time Column", required=true, filter=CoxPHVecSelect.class, json=true) public Vec stop_column; @API(help="Event Column", required=true, filter=CoxPHVecSelect.class, json=true) public Vec event_column; @API(help="X Columns", required=true, filter=CoxPHMultiVecSelect.class, json=true) public int[] x_columns; @API(help="Weights Column", required=false, filter=CoxPHVecSelect.class, json=true) public Vec weights_column = null; @API(help="Offset Columns", required=false, filter=CoxPHMultiVecSelect.class, json=true) public int[] offset_columns; @API(help="Method for Handling Ties", required=true, filter=Default.class, json=true) public CoxPHTies ties = CoxPHTies.efron; @API(help="coefficient starting value", required=true, filter=Default.class, json=true) public double init = 0; @API(help="minimum log-relative error", required=true, filter=Default.class, json=true) public double lre_min = 9; @API(help="maximum number of iterations", required=true, filter=Default.class, json=true) public int iter_max = 20; private class CoxPHVecSelect extends VecSelect { CoxPHVecSelect() { super("source"); } } private class CoxPHMultiVecSelect extends MultiVecSelect { CoxPHMultiVecSelect() { super("source"); } } public static final int MAX_TIME_BINS = 10000; public static enum CoxPHTies { efron, breslow } public static double[][] malloc2DArray(final int d1, final int d2) { final double[][] array = new double[d1][]; for (int j = 0; j < d1; ++j) array[j] = MemoryManager.malloc8d(d2); return array; } public static double[][][] malloc3DArray(final int d1, final int d2, final int d3) { final double[][][] array = new double[d1][d2][]; for (int j = 0; j < d1; ++j) for (int k = 0; k < d2; ++k) array[j][k] = MemoryManager.malloc8d(d3); return array; } public static class CoxPHModel extends Model implements Job.Progress { static final int API_WEAVER = 1; // This file has auto-generated doc & JSON fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from auto-generated code. @API(help = "model parameters", json = true) final private CoxPH parameters; @API(help="Input data info") DataInfo data_info; @API(help = "names of coefficients") String[] coef_names; @API(help = "coefficients") double[] coef; @API(help = "exp(coefficients)") double[] exp_coef; @API(help = "exp(-coefficients)") double[] exp_neg_coef; @API(help = "se(coefficients)") double[] se_coef; @API(help = "z-score") double[] z_coef; @API(help = "var(coefficients)") double[][] var_coef; @API(help = "null log-likelihood") double null_loglik; @API(help = "log-likelihood") double loglik; @API(help = "log-likelihood test stat") double loglik_test; @API(help = "Wald test stat") double wald_test; @API(help = "Score test stat") double score_test; @API(help = "R-square") double rsq; @API(help = "Maximum R-square") double maxrsq; @API(help = "gradient", json = false) double[] gradient; @API(help = "Hessian", json = false) double[][] hessian; @API(help = "log relative error") double lre; @API(help = "number of iterations") int iter; @API(help = "x weighted mean vector for categorical variables") double[] x_mean_cat; @API(help = "x weighted mean vector for numeric variables") double[] x_mean_num; @API(help = "unweighted mean vector for numeric offsets") double[] mean_offset; @API(help = "names of offsets") String[] offset_names; @API(help = "n") long n; @API(help = "number of rows with missing values") long n_missing; @API(help = "total events") long total_event; @API(help = "minimum time") long min_time; @API(help = "maximum time") long max_time; @API(help = "time") long[] time; @API(help = "number at risk") double[] n_risk; @API(help = "number of events") double[] n_event; @API(help = "number of censored obs") double[] n_censor; @API(help = "baseline cumulative hazard") double[] cumhaz_0; @API(help = "component of var(cumhaz)", json = false) double[] var_cumhaz_1; @API(help = "component of var(cumhaz)", json = false) double[][] var_cumhaz_2; public CoxPHModel(CoxPH job, Key selfKey, Key dataKey, Frame fr, float[] priorClassDist) { super(selfKey, dataKey, fr, priorClassDist); parameters = (CoxPH) job.clone(); } @Override public final CoxPH get_params() { return parameters; } @Override public final Request2 job() { return get_params(); } @Override public float progress() { return (float) iter / (float) get_params().iter_max; } // Following three overrides created for use in super.scoreImpl /* @Override public String[] classNames() { final String[] names = new String[nclasses()]; for (int i = 0; i < time.length; ++i) { final long t = time[i]; names[i] = "cumhaz_" + t; names[i + time.length] = "se_cumhaz_" + t; } return names; } @Override public boolean isClassifier() { return false; } @Override public int nclasses() { return 2 * time.length; } @Override protected float[] score0(double[] data, float[] preds) { final int n_offsets = (parameters.offset_columns == null) ? 0 : parameters.offset_columns.length; final int n_time = time.length; final int n_coef = coef.length; final int n_cats = data_info._cats; final int n_nums = data_info._nums; final int n_data = n_cats + n_nums; final int n_full = n_coef + n_offsets; final int numStart = data_info.numStart(); boolean catsAllNA = true; boolean catsHasNA = false; boolean numsHasNA = false; for (int j = 0; j < n_cats; ++j) { catsAllNA &= Double.isNaN(data[j]); catsHasNA |= Double.isNaN(data[j]); } for (int j = n_cats; j < n_data; ++j) numsHasNA |= Double.isNaN(data[j]); if (numsHasNA || (catsHasNA && !catsAllNA)) { for (int i = 1; i <= 2 * n_time; ++i) preds[i] = Float.NaN; } else { double[] full_data = MemoryManager.malloc8d(n_full); for (int j = 0; j < n_cats; ++j) if (Double.isNaN(data[j])) { final int kst = data_info._catOffsets[j]; final int klen = data_info._catOffsets[j+1] - kst; System.arraycopy(x_mean_cat, kst, full_data, kst, klen); } else if (data[j] != 0) full_data[data_info._catOffsets[j] + (int) (data[j] - 1)] = 1; for (int j = 0; j < n_nums; ++j) full_data[numStart + j] = data[n_cats + j] - data_info._normSub[j]; double logRisk = 0; for (int j = 0; j < n_coef; ++j) logRisk += full_data[j] * coef[j]; for (int j = n_coef; j < full_data.length; ++j) logRisk += full_data[j]; final double risk = Math.exp(logRisk); for (int t = 0; t < n_time; ++t) preds[t + 1] = (float) (risk * cumhaz_0[t]); for (int t = 0; t < n_time; ++t) { final double cumhaz_0_t = cumhaz_0[t]; double var_cumhaz_2_t = 0; for (int j = 0; j < n_coef; ++j) { double sum = 0; for (int k = 0; k < n_coef; ++k) sum += var_coef[j][k] * (full_data[k] * cumhaz_0_t - var_cumhaz_2[t][k]); var_cumhaz_2_t += (full_data[j] * cumhaz_0_t - var_cumhaz_2[t][j]) * sum; } preds[t + 1 + n_time] = (float) (risk * Math.sqrt(var_cumhaz_1[t] + var_cumhaz_2_t)); } } preds[0] = Float.NaN; return preds; } */ @Override protected float[] score0(double[] data, float[] preds) { final int n_offsets = (parameters.offset_columns == null) ? 0 : parameters.offset_columns.length; final int n_cats = data_info._cats; final int n_nums = data_info._nums; final int n_data = n_cats + n_nums; final int numStart = data_info.numStart(); final int n_non_offsets = n_nums - n_offsets; boolean catsAllNA = true; boolean catsHasNA = false; boolean numsHasNA = false; for (int j = 0; j < n_cats; ++j) { catsAllNA &= Double.isNaN(data[j]); catsHasNA |= Double.isNaN(data[j]); } for (int j = n_cats; j < n_data; ++j) numsHasNA |= Double.isNaN(data[j]); if (numsHasNA || (catsHasNA && !catsAllNA)) { preds[0] = Float.NaN; } else { double logRisk = 0; for (int j = 0; j < n_cats; ++j) { final int k_start = data_info._catOffsets[j]; final int k_end = data_info._catOffsets[j + 1]; if (Double.isNaN(data[j])) for (int k = k_start; k < k_end; ++k) logRisk += x_mean_cat[k] * coef[k]; else if (data[j] != 0) logRisk += coef[k_start + (int) (data[j] - 1)]; } for (int j = 0; j < n_non_offsets; ++j) logRisk += (data[n_cats + j] - data_info._normSub[j]) * coef[numStart + j]; for (int j = n_non_offsets; j < n_nums; ++j) logRisk += (data[n_cats + j] - data_info._normSub[j]); preds[0] = (float) Math.exp(logRisk); } return preds; } protected void initStats(final Frame source, final DataInfo dinfo) { n = source.numRows(); data_info = dinfo; final int n_offsets = (parameters.offset_columns == null) ? 0 : parameters.offset_columns.length; final int n_coef = data_info.fullN() - n_offsets; final String[] coefNames = data_info.coefNames(); coef_names = new String[n_coef]; System.arraycopy(coefNames, 0, coef_names, 0, n_coef); coef = MemoryManager.malloc8d(n_coef); exp_coef = MemoryManager.malloc8d(n_coef); exp_neg_coef = MemoryManager.malloc8d(n_coef); se_coef = MemoryManager.malloc8d(n_coef); z_coef = MemoryManager.malloc8d(n_coef); gradient = MemoryManager.malloc8d(n_coef); hessian = malloc2DArray(n_coef, n_coef); var_coef = malloc2DArray(n_coef, n_coef); x_mean_cat = MemoryManager.malloc8d(n_coef - (data_info._nums - n_offsets)); x_mean_num = MemoryManager.malloc8d(data_info._nums - n_offsets); mean_offset = MemoryManager.malloc8d(n_offsets); offset_names = new String[n_offsets]; System.arraycopy(coefNames, n_coef, offset_names, 0, n_offsets); final Vec start_column = source.vec(source.numCols() - 3); final Vec stop_column = source.vec(source.numCols() - 2); min_time = parameters.start_column == null ? (long) stop_column.min(): (long) start_column.min() + 1; max_time = (long) stop_column.max(); final int n_time = new CollectDomain(stop_column).doAll(stop_column).domain().length; time = MemoryManager.malloc8(n_time); n_risk = MemoryManager.malloc8d(n_time); n_event = MemoryManager.malloc8d(n_time); n_censor = MemoryManager.malloc8d(n_time); cumhaz_0 = MemoryManager.malloc8d(n_time); var_cumhaz_1 = MemoryManager.malloc8d(n_time); var_cumhaz_2 = malloc2DArray(n_time, n_coef); } protected void calcCounts(final CoxPHTask coxMR) { n_missing = n - coxMR.n; n = coxMR.n; for (int j = 0; j < x_mean_cat.length; j++) x_mean_cat[j] = coxMR.sumWeightedCatX[j] / coxMR.sumWeights; for (int j = 0; j < x_mean_num.length; j++) x_mean_num[j] = coxMR._dinfo._normSub[j] + coxMR.sumWeightedNumX[j] / coxMR.sumWeights; System.arraycopy(coxMR._dinfo._normSub, x_mean_num.length, mean_offset, 0, mean_offset.length); int nz = 0; for (int t = 0; t < coxMR.countEvents.length; ++t) { total_event += coxMR.countEvents[t]; if (coxMR.sizeEvents[t] > 0 || coxMR.sizeCensored[t] > 0) { time[nz] = min_time + t; n_risk[nz] = coxMR.sizeRiskSet[t]; n_event[nz] = coxMR.sizeEvents[t]; n_censor[nz] = coxMR.sizeCensored[t]; nz++; } } if (parameters.start_column == null) for (int t = n_risk.length - 2; t >= 0; --t) n_risk[t] += n_risk[t + 1]; } protected double calcLoglik(final CoxPHTask coxMR) { final int n_coef = coef.length; final int n_time = coxMR.sizeEvents.length; double newLoglik = 0; for (int j = 0; j < n_coef; ++j) gradient[j] = 0; for (int j = 0; j < n_coef; ++j) for (int k = 0; k < n_coef; ++k) hessian[j][k] = 0; switch (parameters.ties) { case efron: final double[] newLoglik_t = MemoryManager.malloc8d(n_time); final double[][] gradient_t = malloc2DArray(n_time, n_coef); final double[][][] hessian_t = malloc3DArray(n_time, n_coef, n_coef); ForkJoinTask[] fjts = new ForkJoinTask[n_time]; for (int t = n_time - 1; t >= 0; --t) { final int _t = t; fjts[t] = new RecursiveAction() { @Override protected void compute() { final double sizeEvents_t = coxMR.sizeEvents[_t]; if (sizeEvents_t > 0) { final long countEvents_t = coxMR.countEvents[_t]; final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[_t]; final double sumRiskEvents_t = coxMR.sumRiskEvents[_t]; final double rcumsumRisk_t = coxMR.rcumsumRisk[_t]; final double avgSize = sizeEvents_t / countEvents_t; newLoglik_t[_t] = sumLogRiskEvents_t; System.arraycopy(coxMR.sumXEvents[_t], 0, gradient_t[_t], 0, n_coef); for (long e = 0; e < countEvents_t; ++e) { final double frac = ((double) e) / ((double) countEvents_t); final double term = rcumsumRisk_t - frac * sumRiskEvents_t; newLoglik_t[_t] -= avgSize * Math.log(term); for (int j = 0; j < n_coef; ++j) { final double djTerm = coxMR.rcumsumXRisk[_t][j] - frac * coxMR.sumXRiskEvents[_t][j]; final double djLogTerm = djTerm / term; gradient_t[_t][j] -= avgSize * djLogTerm; for (int k = 0; k < n_coef; ++k) { final double dkTerm = coxMR.rcumsumXRisk[_t][k] - frac * coxMR.sumXRiskEvents[_t][k]; final double djkTerm = coxMR.rcumsumXXRisk[_t][j][k] - frac * coxMR.sumXXRiskEvents[_t][j][k]; hessian_t[_t][j][k] -= avgSize * (djkTerm / term - (djLogTerm * (dkTerm / term))); } } } } } }; } ForkJoinTask.invokeAll(fjts); for (int t = 0; t < n_time; ++t) newLoglik += newLoglik_t[t]; for (int t = 0; t < n_time; ++t) for (int j = 0; j < n_coef; ++j) gradient[j] += gradient_t[t][j]; for (int t = 0; t < n_time; ++t) for (int j = 0; j < n_coef; ++j) for (int k = 0; k < n_coef; ++k) hessian[j][k] += hessian_t[t][j][k]; break; case breslow: for (int t = n_time - 1; t >= 0; --t) { final double sizeEvents_t = coxMR.sizeEvents[t]; if (sizeEvents_t > 0) { final double sumLogRiskEvents_t = coxMR.sumLogRiskEvents[t]; final double rcumsumRisk_t = coxMR.rcumsumRisk[t]; newLoglik += sumLogRiskEvents_t; newLoglik -= sizeEvents_t * Math.log(rcumsumRisk_t); for (int j = 0; j < n_coef; ++j) { final double dlogTerm = coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t; gradient[j] += coxMR.sumXEvents[t][j]; gradient[j] -= sizeEvents_t * dlogTerm; for (int k = 0; k < n_coef; ++k) hessian[j][k] -= sizeEvents_t * (((coxMR.rcumsumXXRisk[t][j][k] / rcumsumRisk_t) - (dlogTerm * (coxMR.rcumsumXRisk[t][k] / rcumsumRisk_t)))); } } } break; default: throw new IllegalArgumentException("ties method must be either efron or breslow"); } return newLoglik; } protected void calcModelStats(final double[] newCoef, final double newLoglik) { final int n_coef = coef.length; final Matrix inv_hessian = new Matrix(hessian).inverse(); for (int j = 0; j < n_coef; ++j) { for (int k = 0; k <= j; ++k) { final double elem = -inv_hessian.get(j, k); var_coef[j][k] = elem; var_coef[k][j] = elem; } } for (int j = 0; j < n_coef; ++j) { coef[j] = newCoef[j]; exp_coef[j] = Math.exp(coef[j]); exp_neg_coef[j] = Math.exp(- coef[j]); se_coef[j] = Math.sqrt(var_coef[j][j]); z_coef[j] = coef[j] / se_coef[j]; } if (iter == 0) { null_loglik = newLoglik; maxrsq = 1 - Math.exp(2 * null_loglik / n); score_test = 0; for (int j = 0; j < n_coef; ++j) { double sum = 0; for (int k = 0; k < n_coef; ++k) sum += var_coef[j][k] * gradient[k]; score_test += gradient[j] * sum; } } loglik = newLoglik; loglik_test = - 2 * (null_loglik - loglik); rsq = 1 - Math.exp(- loglik_test / n); wald_test = 0; for (int j = 0; j < n_coef; ++j) { double sum = 0; for (int k = 0; k < n_coef; ++k) sum -= hessian[j][k] * (coef[k] - parameters.init); wald_test += (coef[j] - parameters.init) * sum; } } protected void calcCumhaz_0(final CoxPHTask coxMR) { final int n_coef = coef.length; int nz = 0; switch (parameters.ties) { case efron: for (int t = 0; t < coxMR.sizeEvents.length; ++t) { final double sizeEvents_t = coxMR.sizeEvents[t]; final double sizeCensored_t = coxMR.sizeCensored[t]; if (sizeEvents_t > 0 || sizeCensored_t > 0) { final long countEvents_t = coxMR.countEvents[t]; final double sumRiskEvents_t = coxMR.sumRiskEvents[t]; final double rcumsumRisk_t = coxMR.rcumsumRisk[t]; final double avgSize = sizeEvents_t / countEvents_t; cumhaz_0[nz] = 0; var_cumhaz_1[nz] = 0; for (int j = 0; j < n_coef; ++j) var_cumhaz_2[nz][j] = 0; for (long e = 0; e < countEvents_t; ++e) { final double frac = ((double) e) / ((double) countEvents_t); final double haz = 1 / (rcumsumRisk_t - frac * sumRiskEvents_t); final double haz_sq = haz * haz; cumhaz_0[nz] += avgSize * haz; var_cumhaz_1[nz] += avgSize * haz_sq; for (int j = 0; j < n_coef; ++j) var_cumhaz_2[nz][j] += avgSize * ((coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) * haz_sq); } nz++; } } break; case breslow: for (int t = 0; t < coxMR.sizeEvents.length; ++t) { final double sizeEvents_t = coxMR.sizeEvents[t]; final double sizeCensored_t = coxMR.sizeCensored[t]; if (sizeEvents_t > 0 || sizeCensored_t > 0) { final double rcumsumRisk_t = coxMR.rcumsumRisk[t]; final double cumhaz_0_nz = sizeEvents_t / rcumsumRisk_t; cumhaz_0[nz] = cumhaz_0_nz; var_cumhaz_1[nz] = sizeEvents_t / (rcumsumRisk_t * rcumsumRisk_t); for (int j = 0; j < n_coef; ++j) var_cumhaz_2[nz][j] = (coxMR.rcumsumXRisk[t][j] / rcumsumRisk_t) * cumhaz_0_nz; nz++; } } break; default: throw new IllegalArgumentException("ties method must be either efron or breslow"); } for (int t = 1; t < cumhaz_0.length; ++t) { cumhaz_0[t] = cumhaz_0[t - 1] + cumhaz_0[t]; var_cumhaz_1[t] = var_cumhaz_1[t - 1] + var_cumhaz_1[t]; for (int j = 0; j < n_coef; ++j) var_cumhaz_2[t][j] = var_cumhaz_2[t - 1][j] + var_cumhaz_2[t][j]; } } public Frame makeSurvfit(final Key key, double x_new) { // FIXME int j = 0; if (Double.isNaN(x_new)) x_new = data_info._normSub[j]; final int n_time = time.length; final Vec[] vecs = Vec.makeNewCons((long) n_time, 4, 0, null); final Vec timevec = vecs[0]; final Vec cumhaz = vecs[1]; final Vec se_cumhaz = vecs[2]; final Vec surv = vecs[3]; final double x_centered = x_new - data_info._normSub[j]; final double risk = Math.exp(coef[j] * x_centered); for (int t = 0; t < n_time; ++t) timevec.set(t, time[t]); for (int t = 0; t < n_time; ++t) { final double cumhaz_1 = risk * cumhaz_0[t]; cumhaz.set(t, cumhaz_1); surv.set(t, Math.exp(-cumhaz_1)); } for (int t = 0; t < n_time; ++t) { final double gamma = x_centered * cumhaz_0[t] - var_cumhaz_2[t][j]; se_cumhaz.set(t, risk * Math.sqrt(var_cumhaz_1[t] + (gamma * var_coef[j][j] * gamma))); } final Frame fr = new Frame(key, new String[] {"time", "cumhaz", "se_cumhaz", "surv"}, vecs); final Futures fs = new Futures(); DKV.put(key, fr, fs); fs.blockForPending(); return fr; } public void generateHTML(final String title, final StringBuilder sb) { DocGen.HTML.title(sb, title); sb.append("<h4>Data</h4>"); sb.append("<table class='table table-striped table-bordered table-condensed'><col width=\"25%\"><col width=\"75%\">"); sb.append("<tr><th>Number of Complete Cases</th><td>"); sb.append(n); sb.append("</td></tr>"); sb.append("<tr><th>Number of Non Complete Cases</th><td>"); sb.append(n_missing); sb.append("</td></tr>"); sb.append("<tr><th>Number of Events in Complete Cases</th><td>");sb.append(total_event);sb.append("</td></tr>"); sb.append("</table>"); sb.append("<h4>Coefficients</h4>"); sb.append("<table class='table table-striped table-bordered table-condensed'>"); sb.append("<tr><th></th><th>coef</th><th>exp(coef)</th><th>se(coef)</th><th>z</th></tr>"); for (int j = 0; j < coef.length; ++j) { sb.append("<tr><th>"); sb.append(coef_names[j]);sb.append("</th><td>");sb.append(coef[j]); sb.append("</td><td>"); sb.append(exp_coef[j]); sb.append("</td><td>");sb.append(se_coef[j]);sb.append("</td><td>"); sb.append(z_coef[j]); sb.append("</td></tr>"); } sb.append("</table>"); sb.append("<h4>Model Statistics</h4>"); sb.append("<table class='table table-striped table-bordered table-condensed'><col width=\"15%\"><col width=\"85%\">"); sb.append("<tr><th>Rsquare</th><td>");sb.append(String.format("%.3f", rsq)); sb.append(" (max possible = "); sb.append(String.format("%.3f", maxrsq));sb.append(")</td></tr>"); sb.append("<tr><th>Likelihood ratio test</th><td>");sb.append(String.format("%.2f", loglik_test)); sb.append(" on ");sb.append(coef.length);sb.append(" df</td></tr>"); sb.append("<tr><th>Wald test </th><td>");sb.append(String.format("%.2f", wald_test)); sb.append(" on ");sb.append(coef.length);sb.append(" df</td></tr>"); sb.append("<tr><th>Score (logrank) test </th><td>");sb.append(String.format("%.2f", score_test)); sb.append(" on ");sb.append(coef.length);sb.append(" df</td></tr>"); sb.append("</table>"); } public void toJavaHtml(StringBuilder sb) { } } private CoxPHModel model; @Override protected void init() { super.init(); if ((start_column != null) && !start_column.isInt()) throw new IllegalArgumentException("start time must be null or of type integer"); if (!stop_column.isInt()) throw new IllegalArgumentException("stop time must be of type integer"); if (!event_column.isInt() && !event_column.isEnum()) throw new IllegalArgumentException("event must be of type integer or factor"); if ((event_column.isInt() && (event_column.min() == event_column.max())) || (event_column.isEnum() && (event_column.cardinality() < 2))) throw new IllegalArgumentException("event column contains less than two distinct values"); if (Double.isNaN(lre_min) || lre_min <= 0) throw new IllegalArgumentException("lre_min must be a positive number"); if (iter_max < 1) throw new IllegalArgumentException("iter_max must be a positive integer"); final long min_time = (start_column == null) ? (long) stop_column.min() : (long) start_column.min() + 1; final int n_time = (int) (stop_column.max() - min_time + 1); if (n_time < 1) throw new IllegalArgumentException("start times must be strictly less than stop times"); if (n_time > MAX_TIME_BINS) throw new IllegalArgumentException("number of distinct stop times is " + n_time + "; maximum number allowed is " + MAX_TIME_BINS); source = getSubframe(); int n_resp = 2; if (weights_column != null) n_resp++; if (start_column != null) n_resp++; final DataInfo dinfo = new DataInfo(source, n_resp, false, false, DataInfo.TransformType.DEMEAN); model = new CoxPHModel(this, dest(), source._key, source, null); model.initStats(source, dinfo); } @Override protected void execImpl() { final DataInfo dinfo = model.data_info; final int n_offsets = (model.parameters.offset_columns == null) ? 0 : model.parameters.offset_columns.length; final int n_coef = dinfo.fullN() - n_offsets; final double[] step = MemoryManager.malloc8d(n_coef); final double[] oldCoef = MemoryManager.malloc8d(n_coef); final double[] newCoef = MemoryManager.malloc8d(n_coef); Arrays.fill(step, Double.NaN); Arrays.fill(oldCoef, Double.NaN); for (int j = 0; j < n_coef; ++j) newCoef[j] = init; double oldLoglik = - Double.MAX_VALUE; final int n_time = (int) (model.max_time - model.min_time + 1); final boolean has_start_column = (model.parameters.start_column != null); final boolean has_weights_column = (model.parameters.weights_column != null); for (int i = 0; i <= iter_max; ++i) { model.iter = i; final CoxPHTask coxMR = new CoxPHTask(self(), dinfo, newCoef, model.min_time, n_time, n_offsets, has_start_column, has_weights_column).doAll(dinfo._adaptedFrame); final double newLoglik = model.calcLoglik(coxMR); if (newLoglik > oldLoglik) { if (i == 0) model.calcCounts(coxMR); model.calcModelStats(newCoef, newLoglik); model.calcCumhaz_0(coxMR); if (newLoglik == 0) model.lre = - Math.log10(Math.abs(oldLoglik - newLoglik)); else model.lre = - Math.log10(Math.abs((oldLoglik - newLoglik) / newLoglik)); if (model.lre >= lre_min) break; Arrays.fill(step, 0); for (int j = 0; j < n_coef; ++j) for (int k = 0; k < n_coef; ++k) step[j] -= model.var_coef[j][k] * model.gradient[k]; for (int j = 0; j < n_coef; ++j) if (Double.isNaN(step[j]) || Double.isInfinite(step[j])) break; oldLoglik = newLoglik; System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length); } else { for (int j = 0; j < n_coef; ++j) step[j] /= 2; } for (int j = 0; j < n_coef; ++j) newCoef[j] = oldCoef[j] - step[j]; } final Futures fs = new Futures(); DKV.put(dest(), model, fs); fs.blockForPending(); } @Override protected Response redirect() { return CoxPHProgressPage.redirect(this, self(), dest()); } private Frame getSubframe() { final boolean use_start_column = (start_column != null); final boolean use_weights_column = (weights_column != null); final int x_ncol = x_columns.length; final int offset_ncol = offset_columns == null ? 0 : offset_columns.length; int ncol = x_ncol + offset_ncol + 2; if (use_weights_column) ncol++; if (use_start_column) ncol++; final String[] names = new String[ncol]; for (int j = 0; j < x_ncol; ++j) names[j] = source.names()[x_columns[j]]; for (int j = 0; j < offset_ncol; ++j) names[x_ncol + j] = source.names()[offset_columns[j]]; if (use_weights_column) names[x_ncol + offset_ncol] = source.names()[source.find(weights_column)]; if (use_start_column) names[ncol - 3] = source.names()[source.find(start_column)]; names[ncol - 2] = source.names()[source.find(stop_column)]; names[ncol - 1] = source.names()[source.find(event_column)]; return source.subframe(names); } protected static class CoxPHTask extends FrameTask<CoxPHTask> { private final double[] _beta; private final int _n_time; private final long _min_time; private final int _n_offsets; private final boolean _has_start_column; private final boolean _has_weights_column; protected long n; protected long n_missing; protected double sumWeights; protected double[] sumWeightedCatX; protected double[] sumWeightedNumX; protected double[] sizeRiskSet; protected double[] sizeCensored; protected double[] sizeEvents; protected long[] countEvents; protected double[][] sumXEvents; protected double[] sumRiskEvents; protected double[][] sumXRiskEvents; protected double[][][] sumXXRiskEvents; protected double[] sumLogRiskEvents; protected double[] rcumsumRisk; protected double[][] rcumsumXRisk; protected double[][][] rcumsumXXRisk; CoxPHTask(Key jobKey, DataInfo dinfo, final double[] beta, final long min_time, final int n_time, final int n_offsets, final boolean has_start_column, final boolean has_weights_column) { super(jobKey, dinfo); _beta = beta; _n_time = n_time; _min_time = min_time; _n_offsets = n_offsets; _has_start_column = has_start_column; _has_weights_column = has_weights_column; } @Override protected void chunkInit(){ final int n_coef = _beta.length; sumWeightedCatX = MemoryManager.malloc8d(n_coef - (_dinfo._nums - _n_offsets)); sumWeightedNumX = MemoryManager.malloc8d(_dinfo._nums); sizeRiskSet = MemoryManager.malloc8d(_n_time); sizeCensored = MemoryManager.malloc8d(_n_time); sizeEvents = MemoryManager.malloc8d(_n_time); countEvents = MemoryManager.malloc8(_n_time); sumRiskEvents = MemoryManager.malloc8d(_n_time); sumLogRiskEvents = MemoryManager.malloc8d(_n_time); rcumsumRisk = MemoryManager.malloc8d(_n_time); sumXEvents = malloc2DArray(_n_time, n_coef); sumXRiskEvents = malloc2DArray(_n_time, n_coef); rcumsumXRisk = malloc2DArray(_n_time, n_coef); sumXXRiskEvents = malloc3DArray(_n_time, n_coef, n_coef); rcumsumXXRisk = malloc3DArray(_n_time, n_coef, n_coef); } @Override protected void processRow(long gid, double [] nums, int ncats, int [] cats, double [] response) { n++; final double weight = _has_weights_column ? response[0] : 1.0; if (weight <= 0) throw new IllegalArgumentException("weights must be positive values"); final long event = (long) response[response.length - 1]; final int t1 = _has_start_column ? (int) (((long) response[response.length - 3] + 1) - _min_time) : -1; final int t2 = (int) (((long) response[response.length - 2]) - _min_time); if (t1 > t2) throw new IllegalArgumentException("start times must be strictly less than stop times"); final int numStart = _dinfo.numStart(); sumWeights += weight; for (int j = 0; j < ncats; ++j) sumWeightedCatX[cats[j]] += weight; for (int j = 0; j < nums.length; ++j) sumWeightedNumX[j] += weight * nums[j]; double logRisk = 0; for (int j = 0; j < ncats; ++j) logRisk += _beta[cats[j]]; for (int j = 0; j < nums.length - _n_offsets; ++j) logRisk += nums[j] * _beta[numStart + j]; for (int j = nums.length - _n_offsets; j < nums.length; ++j) logRisk += nums[j]; final double risk = weight * Math.exp(logRisk); logRisk *= weight; if (event > 0) { countEvents[t2]++; sizeEvents[t2] += weight; sumLogRiskEvents[t2] += logRisk; sumRiskEvents[t2] += risk; } else sizeCensored[t2] += weight; if (_has_start_column) { for (int t = t1; t <= t2; ++t) sizeRiskSet[t] += weight; for (int t = t1; t <= t2; ++t) rcumsumRisk[t] += risk; } else { sizeRiskSet[t2] += weight; rcumsumRisk[t2] += risk; } final int ntotal = ncats + (nums.length - _n_offsets); final int numStartIter = numStart - ncats; for (int jit = 0; jit < ntotal; ++jit) { final boolean jIsCat = jit < ncats; final int j = jIsCat ? cats[jit] : numStartIter + jit; final double x1 = jIsCat ? 1.0 : nums[jit - ncats]; final double xRisk = x1 * risk; if (event > 0) { sumXEvents[t2][j] += weight * x1; sumXRiskEvents[t2][j] += xRisk; } if (_has_start_column) { for (int t = t1; t <= t2; ++t) rcumsumXRisk[t][j] += xRisk; } else { rcumsumXRisk[t2][j] += xRisk; } for (int kit = 0; kit < ntotal; ++kit) { final boolean kIsCat = kit < ncats; final int k = kIsCat ? cats[kit] : numStartIter + kit; final double x2 = kIsCat ? 1.0 : nums[kit - ncats]; final double xxRisk = x2 * xRisk; if (event > 0) sumXXRiskEvents[t2][j][k] += xxRisk; if (_has_start_column) { for (int t = t1; t <= t2; ++t) rcumsumXXRisk[t][j][k] += xxRisk; } else { rcumsumXXRisk[t2][j][k] += xxRisk; } } } } @Override public void reduce(CoxPHTask that) { n += that.n; sumWeights += that.sumWeights; Utils.add(sumWeightedCatX, that.sumWeightedCatX); Utils.add(sumWeightedNumX, that.sumWeightedNumX); Utils.add(sizeRiskSet, that.sizeRiskSet); Utils.add(sizeCensored, that.sizeCensored); Utils.add(sizeEvents, that.sizeEvents); Utils.add(countEvents, that.countEvents); Utils.add(sumXEvents, that.sumXEvents); Utils.add(sumRiskEvents, that.sumRiskEvents); Utils.add(sumXRiskEvents, that.sumXRiskEvents); Utils.add(sumXXRiskEvents, that.sumXXRiskEvents); Utils.add(sumLogRiskEvents, that.sumLogRiskEvents); Utils.add(rcumsumRisk, that.rcumsumRisk); Utils.add(rcumsumXRisk, that.rcumsumXRisk); Utils.add(rcumsumXXRisk, that.rcumsumXXRisk); } @Override protected void postGlobal() { if (!_has_start_column) { for (int t = rcumsumRisk.length - 2; t >= 0; --t) rcumsumRisk[t] += rcumsumRisk[t + 1]; for (int t = rcumsumXRisk.length - 2; t >= 0; --t) for (int j = 0; j < rcumsumXRisk[t].length; ++j) rcumsumXRisk[t][j] += rcumsumXRisk[t + 1][j]; for (int t = rcumsumXXRisk.length - 2; t >= 0; --t) for (int j = 0; j < rcumsumXXRisk[t].length; ++j) for (int k = 0; k < rcumsumXXRisk[t][j].length; ++k) rcumsumXXRisk[t][j][k] += rcumsumXXRisk[t + 1][j][k]; } } } }