package hex.genmodel.algos.glrm;

import hex.ModelCategory;
import hex.genmodel.MojoModel;

import java.util.EnumSet;
import java.util.Random;

/**
 * MOJO scorer for the GLRM (Generalized Low Rank Model) algorithm.
 */
public class GlrmMojoModel extends MojoModel {
  public int _ncolA;
  public int _ncolX;
  public int _ncolY;
  public int _nrowY;
  public double[][] _archetypes;
  public int[] _numLevels;
  public int[] _permutation;
  public GlrmLoss[] _losses;
  public GlrmRegularizer _regx;
  public double _gammax;
  public GlrmInitialization _init;
  public int _ncats;
  public int _nnums;
  public double[] _normSub;
  public double[] _normMul;
  // We don't really care about regularization of Y since it is not used during scoring

  /**
   * This is the "learning rate" in the gradient descent method. More specifically, at each iteration step we update
   * x according to x_new = x_old - alpha * grad_x(obj)(x_old). If the objective evaluated at x_new is smaller than
   * the objective at x_old, then we proceed with the update, increasing alpha slightly (in case we are learning too
   * slowly); however, if the objective at x_new is bigger than the original objective, then we have "overshot" and
   * therefore halve alpha.
   * When reusing alpha between multiple computations of the gradient, we find that alpha eventually "stabilizes"
   * in a certain range; moreover, that range is roughly the same when scoring different rows. This is why alpha was
   * made static -- so that its value from the previous scoring round can be reused to achieve faster convergence.
   * This approach is not thread-safe! If we ever make GenModel capable of scoring multiple rows in parallel, this
   * will have to be changed to make updates to alpha synchronized.
   */
  private static double alpha = 1.0;
  private static final double DOWN_FACTOR = 0.5;
  private static final double UP_FACTOR = Math.pow(1.0 / DOWN_FACTOR, 1.0 / 4);
  static {
    //noinspection ConstantAssertCondition,ConstantConditions
    assert DOWN_FACTOR < 1 && DOWN_FACTOR > 0;
    assert UP_FACTOR > 1;
  }

  private static final EnumSet<ModelCategory> CATEGORIES =
      EnumSet.of(ModelCategory.AutoEncoder, ModelCategory.DimReduction);

  @Override
  public EnumSet<ModelCategory> getModelCategories() {
    return CATEGORIES;
  }

  protected GlrmMojoModel(String[] columns, String[][] domains) {
    super(columns, domains);
  }

  @Override
  public int getPredsSize(ModelCategory mc) {
    return _ncolX;
  }

  /**
   * This function corresponds to the DimReduction model category.
   */
  @Override
  public double[] score0(double[] row, double[] preds) {
    assert row.length == _ncolA;
    assert preds.length == _ncolX;
    assert _nrowY == _ncolX;
    assert _archetypes.length == _nrowY;
    assert _archetypes[0].length == _ncolY;

    // Step 0: prepare the data row
    double[] a = new double[_ncolA];
    for (int i = 0; i < _ncolA; i++)
      a[i] = row[_permutation[i]];

    // Step 1: initialize X (for now do random initialization only)
    double[] x = new double[_ncolX];
    Random random = new Random();
    for (int i = 0; i < _ncolX; i++)
      x[i] = random.nextGaussian();
    x = _regx.project(x, random);

    // Step 2: update X using a proximal gradient method, iterating until convergence
    double obj = objective(x, a);
    boolean done = false;
    int iters = 0;
    while (!done && iters++ < 100) {
      // Compute the gradient of the loss function
      double[] grad = gradientL(x, a);

      // Try to make a step of size alpha, until we can achieve improvement in the objective.
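      // Backtracking line search on the proximal-gradient step:
      //   u    = x - alpha * grad              (gradient step on the loss)
      //   xnew = rproxgrad(u, alpha * gammax)  (proximal step for the regularizer)
      // If the objective improves, accept xnew and grow alpha by UP_FACTOR (stopping once the
      // relative improvement drops below 1e-6); otherwise shrink alpha by DOWN_FACTOR and retry.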
      double[] u = new double[_ncolX];
      while (true) {
        // System.out.println(" " + alpha);
        // Compute the tentative new x (using the prox algorithm)
        for (int k = 0; k < _ncolX; k++) {
          u[k] = x[k] - alpha * grad[k];
        }
        double[] xnew = _regx.rproxgrad(u, alpha * _gammax, random);
        double newobj = objective(xnew, a);
        if (newobj == 0) {
          // A zero objective cannot be improved upon: accept the step and stop iterating.
          x = xnew;
          obj = newobj;
          done = true;
          break;
        }
        double obj_improvement = 1 - newobj / obj;
        if (obj_improvement >= 0) {
          if (obj_improvement < 1e-6) done = true;
          obj = newobj;
          x = xnew;
          alpha *= UP_FACTOR;
          break;
        } else {
          alpha *= DOWN_FACTOR;
        }
      }
    }

    // Step 3: return the result
    // System.out.println("obj = " + obj + ", alpha = " + alpha + ", n_iters = " + iters);
    System.arraycopy(x, 0, preds, 0, _ncolX);
    return preds;
  }

  /**
   * Compute the gradient of the objective function with respect to x, i.e. d/dx Sum_j[L_j(xY_j, a)]
   * @param x current x row
   * @param a the adapted data row
   */
  private double[] gradientL(double[] x, double[] a) {
    // Prepare output row
    double[] grad = new double[_ncolX];

    // Categorical columns
    int cat_offset = 0;
    for (int j = 0; j < _ncats; j++) {
      if (Double.isNaN(a[j])) continue;   // Skip missing observations in row
      int n_levels = _numLevels[j];

      // Calculate xy = x * Y_j where Y_j is the sub-matrix corresponding to categorical column j
      double[] xy = new double[n_levels];
      for (int level = 0; level < n_levels; level++) {
        for (int k = 0; k < _ncolX; k++) {
          xy[level] += x[k] * _archetypes[k][level + cat_offset];
        }
      }

      // Gradient wrt x is the matrix product \grad L_j(x * Y_j, A_j) * Y_j'
      double[] gradL = _losses[j].mlgrad(xy, (int) a[j]);
      for (int k = 0; k < _ncolX; k++) {
        for (int c = 0; c < n_levels; c++)
          grad[k] += gradL[c] * _archetypes[k][c + cat_offset];
      }
      cat_offset += n_levels;
    }

    // Numeric columns
    for (int j = _ncats; j < _ncolA; j++) {
      int js = j - _ncats;
      if (Double.isNaN(a[j])) continue;   // Skip missing observations in row

      // Inner product x * y_j
      double xy = 0;
      for (int k = 0; k < _ncolX; k++)
        xy += x[k] * _archetypes[k][js + cat_offset];

      // Sum over y_j weighted by the gradient of the loss \grad L_j(x * y_j, A_j)
      double gradL = _losses[j].lgrad(xy, (a[j] - _normSub[js]) * _normMul[js]);
      for (int k = 0; k < _ncolX; k++)
        grad[k] += gradL * _archetypes[k][js + cat_offset];
    }
    return grad;
  }

  private double objective(double[] x, double[] a) {
    double res = 0;

    // Loss: categorical columns
    int cat_offset = 0;
    for (int j = 0; j < _ncats; j++) {
      if (Double.isNaN(a[j])) continue;   // Skip missing observations in row
      int n_levels = _numLevels[j];
      double[] xy = new double[n_levels];
      for (int level = 0; level < n_levels; level++) {
        for (int k = 0; k < _ncolX; k++) {
          xy[level] += x[k] * _archetypes[k][level + cat_offset];
        }
      }
      res += _losses[j].mloss(xy, (int) a[j]);
      cat_offset += n_levels;
    }

    // Loss: numeric columns
    for (int j = _ncats; j < _ncolA; j++) {
      int js = j - _ncats;
      if (Double.isNaN(a[j])) continue;   // Skip missing observations in row
      double xy = 0;
      for (int k = 0; k < _ncolX; k++)
        xy += x[k] * _archetypes[k][js + cat_offset];
      res += _losses[j].loss(xy, (a[j] - _normSub[js]) * _normMul[js]);
    }

    // Regularization of X
    res += _gammax * _regx.regularize(x);
    return res;
  }
}
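
/*
 * A minimal usage sketch, not part of the MOJO scorer itself and normally kept in its own file:
 * it shows one way score0() could be driven directly after loading a GLRM MOJO exported from H2O.
 * The file name "glrm.zip" and the all-zero input row are hypothetical placeholders; in practice
 * the row must contain the model's input columns, with categoricals encoded as level indices.
 */
class GlrmMojoScoringExample {
  public static void main(String[] args) throws java.io.IOException {
    // Load the MOJO and cast it to the GLRM scorer (assumes the MOJO was built by GLRM).
    GlrmMojoModel model = (GlrmMojoModel) MojoModel.load("glrm.zip");

    // One input row with as many entries as the model has input columns
    // (all zeros here, purely for illustration).
    double[] row = new double[model._ncolA];

    // score0 fills 'preds' with the low-dimensional representation x of the row.
    double[] preds = new double[model.getPredsSize(ModelCategory.DimReduction)];
    double[] x = model.score0(row, preds);
    System.out.println(java.util.Arrays.toString(x));
  }
}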