package uk.ac.babraham.BamQC.Utilities; /** * The <tt>LinearRegression</tt> class performs a simple linear regression * on an set of <em>N</em> data points (<em>y<sub>i</sub></em>, <em>x<sub>i</sub></em>). * That is, it fits a straight line <em>y</em> = α + β <em>x</em>, * (where <em>y</em> is the response variable, <em>x</em> is the predictor variable, * α is the <em>y-intercept</em>, and β is the <em>slope</em>) * that minimizes the sum of squared residuals of the linear regression model. * It also computes associated statistics, including the coefficient of * determination <em>R</em><sup>2</sup> and the standard deviation of the * estimates for the slope and <em>y</em>-intercept. * * Source code: http://algs4.cs.princeton.edu/14analysis/LinearRegression.java * (GPLv3 License) * * @author Robert Sedgewick * @author Kevin Wayne */ public class LinearRegression { private final int N; private final double intercept, slope; private final double R2; private final double svar, svar0, svar1; /** * Performs a linear regression on the data points <tt>(y[i], x[i])</tt>. * * @param x the values of the predictor variable * @param y the corresponding values of the response variable * @throws IllegalArgumentException if the lengths of the two arrays are not equal */ public LinearRegression(double[] x, double[] y) { if (x.length != y.length) { throw new IllegalArgumentException("Array lengths are not equal"); } N = x.length; // first pass double sumx = 0.0, sumy = 0.0; for (int i = 0; i < N; i++) sumx += x[i]; for (int i = 0; i < N; i++) { } for (int i = 0; i < N; i++) sumy += y[i]; double xbar = sumx / N; double ybar = sumy / N; // second pass: compute summary statistics double xxbar = 0.0, yybar = 0.0, xybar = 0.0; for (int i = 0; i < N; i++) { xxbar += (x[i] - xbar) * (x[i] - xbar); yybar += (y[i] - ybar) * (y[i] - ybar); xybar += (x[i] - xbar) * (y[i] - ybar); } slope = xybar / xxbar; intercept = ybar - slope * xbar; // more statistical analysis double rss = 0.0; // residual sum of squares double ssr = 0.0; // regression sum of squares for (int i = 0; i < N; i++) { double fit = slope*x[i] + intercept; rss += (fit - y[i]) * (fit - y[i]); ssr += (fit - ybar) * (fit - ybar); } int degreesOfFreedom = N-2; R2 = ssr / yybar; svar = rss / degreesOfFreedom; svar1 = svar / xxbar; svar0 = svar/N + xbar*xbar*svar1; } /** * Returns the <em>y</em>-intercept α of the best of the best-fit line <em>y</em> = α + β <em>x</em>. * * @return the <em>y</em>-intercept α of the best-fit line <em>y = α + β x</em> */ public double intercept() { return intercept; } /** * Returns the slope β of the best of the best-fit line <em>y</em> = α + β <em>x</em>. * * @return the slope β of the best-fit line <em>y</em> = α + β <em>x</em> */ public double slope() { return slope; } /** * Returns the coefficient of determination <em>R</em><sup>2</sup>. * * @return the coefficient of determination <em>R</em><sup>2</sup>, * which is a real number between 0 and 1 */ public double R2() { return R2; } /** * Returns the standard error of the estimate for the intercept. * * @return the standard error of the estimate for the intercept */ public double interceptStdErr() { return Math.sqrt(svar0); } /** * Returns the standard error of the estimate for the slope. * * @return the standard error of the estimate for the slope */ public double slopeStdErr() { return Math.sqrt(svar1); } /** * Returns the expected response <tt>y</tt> given the value of the predictor * variable <tt>x</tt>. * * @param x the value of the predictor variable * @return the expected response <tt>y</tt> given the value of the predictor * variable <tt>x</tt> */ public double predict(double x) { return slope*x + intercept; } /** * Returns a string representation of the simple linear regression model. * * @return a string representation of the simple linear regression model, * including the best-fit line and the coefficient of determination * <em>R</em><sup>2</sup> */ @Override public String toString() { String s = ""; s += String.format("%.2f N + %.2f", slope(), intercept()); return s + " (R^2 = " + String.format("%.3f", R2()) + ")"; } }