/*******************************************************************************
 * Copyright (c) 2010 Haifeng Li
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package smile.regression;

import java.io.Serializable;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.math.matrix.Matrix;
import smile.math.matrix.NaiveMatrix;
import smile.math.matrix.RowMajorMatrix;
import smile.math.matrix.SparseMatrix;
import smile.math.matrix.BiconjugateGradient;
import smile.math.matrix.Preconditioner;
import smile.math.special.Beta;

/**
 * Lasso (least absolute shrinkage and selection operator) regression.
 * The Lasso is a shrinkage and selection method for linear regression.
 * It minimizes the usual sum of squared errors, with a bound on the sum
 * of the absolute values of the coefficients (i.e. L<sub>1</sub>-regularized).
 * It has connections to soft-thresholding of wavelet coefficients, forward
 * stage-wise regression, and boosting methods.
 * <p>
 * The Lasso typically yields a sparse solution, of which the parameter
 * vector β has relatively few nonzero coefficients. In contrast, the
 * solution of L<sub>2</sub>-regularized least squares (i.e. ridge regression)
 * typically has all coefficients nonzero. Because it effectively
 * reduces the number of variables, the Lasso is useful in some contexts.
 * <p>
 * For over-determined systems (more instances than variables, commonly in
 * machine learning), we normalize variables with mean 0 and standard deviation
 * 1. For under-determined systems (fewer instances than variables, e.g.
 * compressed sensing), we assume white noise (i.e. no intercept in the linear
 * model) and do not perform normalization. Note that the solution
 * is not unique in this case.
 * <p>
 * There is no analytic formula or expression for the optimal solution to the
 * L<sub>1</sub>-regularized least squares problem. Therefore, its solution
 * must be computed numerically. The objective function in the
 * L<sub>1</sub>-regularized least squares is convex but not differentiable,
 * so solving it is more of a computational challenge than solving the
 * L<sub>2</sub>-regularized least squares. The Lasso may be solved using
 * quadratic programming or more general convex optimization methods, as well
 * as by specific algorithms such as the least angle regression algorithm.
 *
 * <h2>References</h2>
 * <ol>
 * <li> R. Tibshirani. Regression shrinkage and selection via the lasso. J. Royal. Statist. Soc B., 58(1):267-288, 1996.</li>
 * <li> B. Efron, I. Johnstone, T. Hastie, and R. Tibshirani. Least angle regression. Annals of Statistics, 2003.</li>
 * <li> S.-J. Kim, K. Koh, M. Lustig, S. Boyd, and D. Gorinevsky. An Interior-Point Method for Large-Scale L1-Regularized Least Squares. IEEE Journal of Selected Topics in Signal Processing, 1(4), 2007.</li>
 * </ol>
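 *
 * <h2>Example</h2>
 * A minimal usage sketch. The data below are made-up numbers purely for
 * illustration; any {@code n x p} design matrix and length-{@code n}
 * response will do:
 * <pre>{@code
 * // Five samples, two explanatory variables. NO constant column of 1s:
 * // the intercept is estimated internally.
 * double[][] x = {
 *     {1.0, 7.0}, {2.0, 5.0}, {3.0, 6.0}, {4.0, 4.0}, {5.0, 8.0}
 * };
 * double[] y = {6.2, 7.1, 8.9, 10.3, 12.0};
 *
 * LASSO lasso = new LASSO(x, y, 0.1);       // shrinkage parameter lambda = 0.1
 * double[] coef = lasso.coefficients();     // typically sparse
 * double intercept = lasso.intercept();
 * double yhat = lasso.predict(new double[] {3.5, 6.0});
 * }</pre>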
 *
 * @author Haifeng Li
 */
public class LASSO implements Regression<double[]>, Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(LASSO.class);

    /**
     * The dimensionality.
     */
    private int p;
    /**
     * The shrinkage/regularization parameter.
     */
    private double lambda;
    /**
     * The intercept.
     */
    private double b;
    /**
     * The linear coefficients.
     */
    private double[] w;
    /**
     * The mean of response variable.
     */
    private double ym;
    /**
     * The center of input vector. The input vector should be centered
     * before prediction.
     */
    private double[] center;
    /**
     * Scaling factor of input vector.
     */
    private double[] scale;
    /**
     * The residuals, that is response minus fitted values.
     */
    private double[] residuals;
    /**
     * Residual sum of squares.
     */
    private double RSS;
    /**
     * Residual standard error.
     */
    private double error;
    /**
     * The degree-of-freedom of residual standard error.
     */
    private int df;
    /**
     * R<sup>2</sup>. R<sup>2</sup> is a statistic that gives some information
     * about the goodness of fit of a model. In regression, the R<sup>2</sup>
     * coefficient of determination is a statistical measure of how well
     * the regression line approximates the real data points. An R<sup>2</sup>
     * of 1.0 indicates that the regression line perfectly fits the data.
     * <p>
     * In the case of ordinary least-squares regression, R<sup>2</sup>
     * increases as we increase the number of variables in the model
     * (R<sup>2</sup> will not decrease). This illustrates a drawback to
     * one possible use of R<sup>2</sup>, where one might try to include
     * more variables in the model until "there is no more improvement".
     * This leads to the alternative approach of looking at the
     * adjusted R<sup>2</sup>.
     */
    private double RSquared;
    /**
     * Adjusted R<sup>2</sup>. The adjusted R<sup>2</sup> has almost the same
     * explanation as R<sup>2</sup>, but it penalizes the statistic as
     * extra variables are included in the model.
     */
    private double adjustedRSquared;
    /**
     * The F-statistic of the goodness-of-fit of the model.
     */
    private double F;
    /**
     * The p-value of the goodness-of-fit test of the model.
     */
    private double pvalue;

    /**
     * Trainer for LASSO regression.
     */
    public static class Trainer extends RegressionTrainer<double[]> {
        /**
         * The shrinkage/regularization parameter.
         */
        private double lambda;
        /**
         * The tolerance for stopping iterations (relative target duality gap).
         */
        private double tol = 1E-3;
        /**
         * The maximum number of IPM (Newton) iterations.
         */
        private int maxIter = 1000;

        /**
         * Constructor.
         *
         * @param lambda the shrinkage/regularization parameter.
         */
        public Trainer(double lambda) {
            if (lambda < 0.0) {
                throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
            }

            this.lambda = lambda;
        }

        /**
         * Sets the tolerance for stopping iterations (relative target duality gap).
         *
         * @param tol the tolerance for stopping iterations.
         */
        public Trainer setTolerance(double tol) {
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }

            this.tol = tol;
            return this;
        }
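
        /*
         * The setters return this trainer, so configuration chains. An
         * illustrative sketch; x and y stand in for the caller's training
         * data and are not part of this class:
         *
         *   LASSO model = new LASSO.Trainer(0.1)   // lambda = 0.1
         *           .setTolerance(1E-4)
         *           .setMaxNumIteration(500)
         *           .train(x, y);
         */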

        /**
         * Sets the maximum number of iterations.
         *
         * @param maxIter the maximum number of iterations.
         */
        public Trainer setMaxNumIteration(int maxIter) {
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }

            this.maxIter = maxIter;
            return this;
        }

        @Override
        public LASSO train(double[][] x, double[] y) {
            return new LASSO(x, y, lambda, tol, maxIter);
        }

        public LASSO train(Matrix x, double[] y) {
            return new LASSO(x, y, lambda, tol, maxIter);
        }
    }

    /**
     * Constructor. Learn the L1-regularized least squares model.
     * @param x a matrix containing the explanatory variables.
     *          NO NEED to include a constant column of 1s for bias.
     * @param y the response values.
     * @param lambda the shrinkage/regularization parameter.
     */
    public LASSO(double[][] x, double[] y, double lambda) {
        this(x, y, lambda, 1E-4, 1000);
    }

    /**
     * Constructor. Learn the L1-regularized least squares model.
     * @param x a matrix containing the explanatory variables.
     *          NO NEED to include a constant column of 1s for bias.
     * @param y the response values.
     * @param lambda the shrinkage/regularization parameter.
     * @param tol the tolerance for stopping iterations (relative target duality gap).
     * @param maxIter the maximum number of IPM (Newton) iterations.
     */
    public LASSO(double[][] x, double[] y, double lambda, double tol, int maxIter) {
        int n = x.length;
        int p = x[0].length;

        center = Math.colMean(x);
        RowMajorMatrix X = new RowMajorMatrix(n, p);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < p; j++) {
                X.set(i, j, x[i][j] - center[j]);
            }
        }

        scale = new double[p];
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                scale[j] += Math.sqr(X.get(i, j));
            }
            scale[j] = Math.sqrt(scale[j] / n);
        }

        for (int j = 0; j < p; j++) {
            if (!Math.isZero(scale[j])) {
                for (int i = 0; i < n; i++) {
                    X.div(i, j, scale[j]);
                }
            }
        }

        train(X, y, lambda, tol, maxIter);

        for (int j = 0; j < p; j++) {
            if (!Math.isZero(scale[j])) {
                w[j] /= scale[j];
            }
        }

        b = ym - Math.dot(w, center);
        fitness(new NaiveMatrix(x), y);
    }
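
    /*
     * A note on the standardization above, spelled out for clarity (this
     * restates what the constructor does; the symbols are not fields of
     * this class). Each column j is transformed as
     *
     *     z_ij = (x_ij - center_j) / scale_j,
     *
     * where center_j is the column mean and scale_j the column standard
     * deviation (with the 1/n convention). If w' solves the standardized
     * problem, the coefficients on the original scale are recovered as
     *
     *     w_j = w'_j / scale_j    and    b = mean(y) - w . center,
     *
     * so predict() can consume raw, unstandardized input vectors.
     */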

    /**
     * Constructor. Learn the L1-regularized least squares model.
     * @param x a matrix containing the explanatory variables. The variables should be
     *          centered and standardized. NO NEED to include a constant column of 1s for bias.
     * @param y the response values.
     * @param lambda the shrinkage/regularization parameter.
     */
    public LASSO(Matrix x, double[] y, double lambda) {
        this(x, y, lambda, 1E-4, 1000);
    }

    /**
     * Constructor. Learn the L1-regularized least squares model.
     * @param x a matrix containing the explanatory variables. The variables should be
     *          centered and standardized. NO NEED to include a constant column of 1s for bias.
     * @param y the response values.
     * @param lambda the shrinkage/regularization parameter.
     * @param tol the tolerance for stopping iterations (relative target duality gap).
     * @param maxIter the maximum number of IPM (Newton) iterations.
     */
    public LASSO(Matrix x, double[] y, double lambda, double tol, int maxIter) {
        train(x, y, lambda, tol, maxIter);
        fitness(x, y);
    }

    private void train(Matrix x, double[] y, double lambda, double tol, int maxIter) {
        if (x.nrows() != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.nrows(), y.length));
        }

        if (lambda <= 0.0) {
            throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
        }

        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }

        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        // Keep the parameter so that shrinkage() reports it.
        this.lambda = lambda;

        // INITIALIZE
        // IPM PARAMETERS
        final int MU = 2;             // updating parameter of t

        // LINE SEARCH PARAMETERS
        final double ALPHA = 0.01;    // minimum fraction of decrease in the objective
        final double BETA = 0.5;      // step size decrease factor
        final int MAX_LS_ITER = 100;  // maximum backtracking line search iterations
        final int pcgmaxi = 5000;     // maximum number of PCG iterations
        final double eta = 1E-3;      // tolerance for PCG termination

        int pitr = 0;
        int n = x.nrows();
        p = x.ncols();

        double[] Y = new double[n];
        ym = Math.mean(y);
        for (int i = 0; i < n; i++) {
            Y[i] = y[i] - ym;
        }

        double t = Math.min(Math.max(1.0, 1.0 / lambda), 2 * p / 1e-3);
        double pobj = 0.0;                      // primal objective function value
        double dobj = Double.NEGATIVE_INFINITY; // dual objective function value
        double s = Double.POSITIVE_INFINITY;

        w = new double[p];
        b = ym;
        double[] u = new double[p];
        double[] z = new double[n];
        double[][] f = new double[2][p];
        Arrays.fill(u, 1.0);
        for (int i = 0; i < p; i++) {
            f[0][i] = w[i] - u[i];
            f[1][i] = -w[i] - u[i];
        }

        double[] neww = new double[p];
        double[] newu = new double[p];
        double[] newz = new double[n];
        double[][] newf = new double[2][p];

        double[] dx = new double[p];
        double[] du = new double[p];
        double[] dxu = new double[2 * p];
        double[] grad = new double[2 * p];
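
        // For reference, the problem the loop below solves (following the
        // interior-point method of Kim et al., reference 3 above). Primal:
        //
        //     minimize ||X*w - Y||^2 + lambda * ||w||_1
        //
        // and the Lagrange dual used for the duality gap, with the dual
        // variable nu feasible when ||X'*nu||_inf <= lambda:
        //
        //     maximize -0.25 * nu'*nu - nu'*Y.
        //
        // In the loop, nu = 2*z (rescaled into the feasible region when
        // necessary), and iteration stops once (pobj - dobj) / dobj < tol.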

        // diagxtx = diag(X'X)
        // X has been standardized so that diag(X'X) is just 1.0.
        // Here we initialize it to 2.0 because we actually need 2 * diag(X'X)
        // during optimization.
        double[] diagxtx = new double[p];
        Arrays.fill(diagxtx, 2.0);

        double[] nu = new double[n];
        double[] xnu = new double[p];

        double[] q1 = new double[p];
        double[] q2 = new double[p];
        double[] d1 = new double[p];
        double[] d2 = new double[p];

        double[][] gradphi = new double[2][p];
        double[] prb = new double[p];
        double[] prs = new double[p];

        PCGMatrix pcg = new PCGMatrix(x, d1, d2, prb, prs);

        // MAIN LOOP
        int ntiter = 0;
        for (; ntiter <= maxIter; ntiter++) {
            x.ax(w, z);
            for (int i = 0; i < n; i++) {
                z[i] -= Y[i];
                nu[i] = 2 * z[i];
            }

            // CALCULATE DUALITY GAP
            x.atx(nu, xnu);
            double maxXnu = Math.normInf(xnu);
            if (maxXnu > lambda) {
                double lnu = lambda / maxXnu;
                for (int i = 0; i < n; i++) {
                    nu[i] *= lnu;
                }
            }

            pobj = Math.dot(z, z) + lambda * Math.norm1(w);
            dobj = Math.max(-0.25 * Math.dot(nu, nu) - Math.dot(nu, Y), dobj);
            if (ntiter % 10 == 0) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", ntiter, pobj, dobj));
            }

            double gap = pobj - dobj;

            // STOPPING CRITERION
            if (gap / dobj < tol) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", ntiter, pobj, dobj));
                break;
            }

            // UPDATE t
            if (s >= 0.5) {
                t = Math.max(Math.min(2 * p * MU / gap, MU * t), t);
            }

            // CALCULATE NEWTON STEP
            for (int i = 0; i < p; i++) {
                double q1i = 1.0 / (u[i] + w[i]);
                double q2i = 1.0 / (u[i] - w[i]);
                q1[i] = q1i;
                q2[i] = q2i;
                d1[i] = (q1i * q1i + q2i * q2i) / t;
                d2[i] = (q1i * q1i - q2i * q2i) / t;
            }

            // calculate gradient
            x.atx(z, gradphi[0]);
            for (int i = 0; i < p; i++) {
                gradphi[0][i] = 2 * gradphi[0][i] - (q1[i] - q2[i]) / t;
                gradphi[1][i] = lambda - (q1[i] + q2[i]) / t;
                grad[i] = -gradphi[0][i];
                grad[i + p] = -gradphi[1][i];
            }

            // calculate vectors to be used in the preconditioner
            for (int i = 0; i < p; i++) {
                prb[i] = diagxtx[i] + d1[i];
                prs[i] = prb[i] * d1[i] - d2[i] * d2[i];
            }

            // set pcg tolerance (relative)
            double normg = Math.norm(grad);
            double pcgtol = Math.min(0.1, eta * gap / Math.min(1.0, normg));
            if (ntiter != 0 && pitr == 0) {
                pcgtol = pcgtol * 0.1;
            }

            // preconditioned conjugate gradient
            double error = BiconjugateGradient.solve(pcg, pcg, grad, dxu, pcgtol, 1, pcgmaxi);
            if (error > pcgtol) {
                pitr = pcgmaxi;
            }

            for (int i = 0; i < p; i++) {
                dx[i] = dxu[i];
                du[i] = dxu[i + p];
            }

            // BACKTRACKING LINE SEARCH
            double phi = Math.dot(z, z) + lambda * Math.sum(u) - sumlogneg(f) / t;
            s = 1.0;
            double gdx = Math.dot(grad, dxu);

            int lsiter = 0;
            for (; lsiter < MAX_LS_ITER; lsiter++) {
                for (int i = 0; i < p; i++) {
                    neww[i] = w[i] + s * dx[i];
                    newu[i] = u[i] + s * du[i];
                    newf[0][i] = neww[i] - newu[i];
                    newf[1][i] = -neww[i] - newu[i];
                }

                if (Math.max(newf) < 0.0) {
                    x.ax(neww, newz);
                    for (int i = 0; i < n; i++) {
                        newz[i] -= Y[i];
                    }

                    double newphi = Math.dot(newz, newz) + lambda * Math.sum(newu) - sumlogneg(newf) / t;
                    if (newphi - phi <= ALPHA * s * gdx) {
                        break;
                    }
                }

                s = BETA * s;
            }

            if (lsiter == MAX_LS_ITER) {
                logger.error("LASSO: Too many iterations of line search.");
                break;
            }

            System.arraycopy(neww, 0, w, 0, p);
            System.arraycopy(newu, 0, u, 0, p);
            System.arraycopy(newf[0], 0, f[0], 0, p);
            System.arraycopy(newf[1], 0, f[1], 0, p);
        }

        // The loop counter ends at maxIter + 1 when the iteration budget is
        // exhausted without convergence.
        if (ntiter > maxIter) {
            logger.error("LASSO: Too many iterations.");
        }
    }
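
    /*
     * For reference, the statistics computed by fitness() below, written out
     * (standard definitions, restated here rather than taken from elsewhere
     * in the library):
     *
     *     RSS = sum (y_i - yhat_i)^2,   TSS = sum (y_i - ybar)^2
     *     R^2 = 1 - RSS / TSS
     *     adjusted R^2 = 1 - (1 - R^2) * (n - 1) / (n - p - 1)
     *     F = ((TSS - RSS) / p) / (RSS / (n - p - 1))
     *
     * with the p-value obtained from the F(p, n - p - 1) distribution via
     * the regularized incomplete beta function.
     */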
    private void fitness(Matrix x, double[] y) {
        int n = y.length;
        double[] yhat = new double[n];
        x.ax(w, yhat);

        double TSS = 0.0;
        RSS = 0.0;
        double ybar = Math.mean(y);
        residuals = new double[n];
        for (int i = 0; i < n; i++) {
            double r = y[i] - yhat[i] - b;
            residuals[i] = r;
            RSS += Math.sqr(r);
            TSS += Math.sqr(y[i] - ybar);
        }

        error = Math.sqrt(RSS / (n - p - 1));
        df = n - p - 1;

        RSquared = 1.0 - RSS / TSS;
        adjustedRSquared = 1.0 - ((1 - RSquared) * (n - 1) / (n - p - 1));

        F = (TSS - RSS) * (n - p - 1) / (RSS * p);
        int df1 = p;
        int df2 = n - p - 1;
        pvalue = Beta.regularizedIncompleteBetaFunction(0.5 * df2, 0.5 * df1, df2 / (df2 + df1 * F));
    }

    /**
     * Returns sum(log(-f)).
     * @param f a matrix.
     * @return sum(log(-f))
     */
    private double sumlogneg(double[][] f) {
        int m = f.length;
        int n = f[0].length;

        double sum = 0.0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                sum += Math.log(-f[i][j]);
            }
        }

        return sum;
    }
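
    /*
     * PCGMatrix below represents the (implicit) Newton-system matrix
     *
     *     hessphi = [ 2*A'A + D1,  D2 ]
     *               [ D2,          D1 ]
     *
     * together with the diagonal-block preconditioner
     *
     *     P = [ diag(2*A'A) + D1,  D2 ]
     *         [ D2,                D1 ],
     *
     * whose asolve() applies P^{-1} by inverting each 2x2 block
     * [prb_i, d2_i; d2_i, d1_i] in closed form, with determinant
     * prs_i = prb_i * d1_i - d2_i^2. This restates what the code does;
     * it is not additional machinery.
     */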
    class PCGMatrix implements Matrix, Preconditioner {

        Matrix A;
        Matrix AtA;
        double[] d1;
        double[] d2;
        double[] prb;
        double[] prs;
        double[] ax;
        double[] atax;

        PCGMatrix(Matrix A, double[] d1, double[] d2, double[] prb, double[] prs) {
            this.A = A;
            this.d1 = d1;
            this.d2 = d2;
            this.prb = prb;
            this.prs = prs;

            int n = A.nrows();
            ax = new double[n];
            atax = new double[p];

            if ((A.ncols() < 10000) && !(A instanceof SparseMatrix)) {
                AtA = A.ata();
            }
        }

        @Override
        public int nrows() {
            return 2 * p;
        }

        @Override
        public int ncols() {
            return 2 * p;
        }

        @Override
        public double[] ax(double[] x, double[] y) {
            // COMPUTE AX (PCG)
            //
            // y = hessphi * x,
            //
            // where hessphi = [A'*A*2+D1, D2;
            //                  D2,        D1];
            if (AtA != null) {
                AtA.ax(x, atax);
            } else {
                A.ax(x, ax);
                A.atx(ax, atax);
            }

            for (int i = 0; i < p; i++) {
                y[i] = 2 * atax[i] + d1[i] * x[i] + d2[i] * x[i + p];
                y[i + p] = d2[i] * x[i] + d1[i] * x[i + p];
            }

            return y;
        }

        @Override
        public double[] atx(double[] x, double[] y) {
            // The matrix is symmetric, so A'x = Ax.
            return ax(x, y);
        }

        @Override
        public void asolve(double[] b, double[] x) {
            // COMPUTE x = P^{-1} * b (PCG preconditioner)
            for (int i = 0; i < p; i++) {
                x[i]     = ( d1[i] * b[i] - d2[i] * b[i + p]) / prs[i];
                x[i + p] = (-d2[i] * b[i] + prb[i] * b[i + p]) / prs[i];
            }
        }

        @Override
        public Matrix transpose() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public Matrix aat() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public Matrix ata() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double get(int i, int j) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double apply(int i, int j) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] axpy(double[] x, double[] y) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] axpy(double[] x, double[] y, double b) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] atxpy(double[] x, double[] y) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public double[] atxpy(double[] x, double[] y, double b) {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    /**
     * Returns the linear coefficients.
     */
    public double[] coefficients() {
        return w;
    }

    /**
     * Returns the intercept.
     */
    public double intercept() {
        return b;
    }

    /**
     * Returns the shrinkage parameter.
     */
    public double shrinkage() {
        return lambda;
    }

    @Override
    public double predict(double[] x) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }

        return Math.dot(x, w) + b;
    }

    /**
     * Returns the residuals, that is response minus fitted values.
     */
    public double[] residuals() {
        return residuals;
    }

    /**
     * Returns the residual sum of squares.
     */
    public double RSS() {
        return RSS;
    }

    /**
     * Returns the residual standard error.
     */
    public double error() {
        return error;
    }

    /**
     * Returns the degree-of-freedom of residual standard error.
     */
    public int df() {
        return df;
    }

    /**
     * Returns the R<sup>2</sup> statistic. In regression, the R<sup>2</sup>
     * coefficient of determination is a statistical measure of how well
     * the regression line approximates the real data points. An R<sup>2</sup>
     * of 1.0 indicates that the regression line perfectly fits the data.
     * <p>
     * In the case of ordinary least-squares regression, R<sup>2</sup>
     * increases as we increase the number of variables in the model
     * (R<sup>2</sup> will not decrease). This illustrates a drawback to
     * one possible use of R<sup>2</sup>, where one might try to include more
     * variables in the model until "there is no more improvement". This leads
     * to the alternative approach of looking at the adjusted R<sup>2</sup>.
     */
    public double RSquared() {
        return RSquared;
    }

    /**
     * Returns the adjusted R<sup>2</sup> statistic. The adjusted R<sup>2</sup>
     * has almost the same explanation as R<sup>2</sup>, but it penalizes the
     * statistic as extra variables are included in the model.
     */
    public double adjustedRSquared() {
        return adjustedRSquared;
    }

    /**
     * Returns the F-statistic of goodness-of-fit.
     */
    public double ftest() {
        return F;
    }

    /**
     * Returns the p-value of goodness-of-fit test.
     */
    public double pvalue() {
        return pvalue;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("LASSO:\n");

        double[] r = residuals.clone();
        builder.append("\nResiduals:\n");
        builder.append("\t Min\t 1Q\t Median\t 3Q\t Max\n");
        builder.append(String.format("\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.4f%n", Math.min(r), Math.q1(r), Math.median(r), Math.q3(r), Math.max(r)));

        builder.append("\nCoefficients:\n");
        builder.append(" Estimate\n");
        builder.append(String.format("Intercept%11.4f%n", b));
        for (int i = 0; i < p; i++) {
            builder.append(String.format("Var %d\t %11.4f%n", i + 1, w[i]));
        }

        builder.append(String.format("%nResidual standard error: %.4f on %d degrees of freedom%n", error, df));
        builder.append(String.format("Multiple R-squared: %.4f, Adjusted R-squared: %.4f%n", RSquared, adjustedRSquared));
        builder.append(String.format("F-statistic: %.4f on %d and %d DF, p-value: %.4g%n", F, p, df, pvalue));

        return builder.toString();
    }
}