package uk.ac.rhul.cs.stats.curvefitting;
import uk.ac.rhul.cs.stats.datastructures.PairedData;
/**
 * Calculates and caches the ordinary least squares straight line fitted to an
 * array of data pairs (x, y).
 *
 * <p>As implemented, the fitted line is {@code y = A * x + B}: {@link #getA()}
 * returns the slope and {@link #getB()} returns the intercept.
 * NOTE(review): this header previously described the line as "y = A + Bx",
 * which does not match what the code computes; confirm the intended convention
 * against {@code StraightLineFit} before relying on either reading.
 *
 * <p>All results are computed lazily on first access and then reused. No
 * synchronization is performed by this class.
 *
 * @author tamas
 */
public class LineFit implements StraightLineFit {
    /** The paired data we are working with. */
    protected PairedData data;

    /** Whether the line fit has already been computed successfully. */
    private boolean calculated = false;

    /** Whether the post-processing analysis has already been performed successfully. */
    private boolean analysed = false;

    /** The mean of X ({@code NaN} when there is no data). */
    protected double meanX;

    /** The mean of Y ({@code NaN} when there is no data). */
    protected double meanY;

    /** The calculated estimate of A, i.e. the slope of the fitted line. */
    protected double calculatedA;

    /** The calculated estimate of B, i.e. the intercept of the fitted line. */
    protected double calculatedB;

    /** The residual sum of squares between the data and the fitted line. */
    protected double rss;

    /**
     * Creates the least squares estimate from the given paired data.
     *
     * <p>Nothing is computed here; the fit is evaluated on the first call to
     * one of the getters.
     *
     * @param data the data pairs (x, y)
     */
    public LineFit(PairedData data) {
        this.data = data;
    }

    /**
     * Performs the actual calculation.
     *
     * <p>This method sets up calculatedA, calculatedB, meanX and meanY. It does
     * nothing if a previous call already completed. For an empty data set all
     * four values are set to {@code NaN}. When the x values have zero spread
     * (for instance a single data point) the slope is the result of a
     * floating-point division by zero and is therefore not finite.
     *
     * <p>The "done" flag is only raised once every value has been stored, so a
     * call that fails with an exception can be retried and never leaves
     * partially accumulated sums in the fields.
     */
    protected void calculate() {
        if (calculated) {
            return;
        }

        double[] xs = data.getX();
        double[] ys = data.getY();
        // assumes getX() and getY() return arrays of equal length -- TODO confirm
        // against PairedData; the x array alone determines the sample count here.
        int n = xs.length;

        if (n == 0) {
            // No data: every statistic is undefined, not zero.
            meanX = Double.NaN;
            meanY = Double.NaN;
            calculatedA = Double.NaN;
            calculatedB = Double.NaN;
            calculated = true;
            return;
        }

        /* Sums are accumulated in locals so that the fields only ever hold final values */
        double sumX = 0.0;
        for (double x : xs) {
            sumX += x;
        }
        double sumY = 0.0;
        for (double y : ys) {
            sumY += y;
        }
        double avgX = sumX / n;
        double avgY = sumY / n;

        /* Centred sum of squares of x and centred cross product of x and y */
        double sxx = 0.0;
        double sxy = 0.0;
        for (int i = 0; i < n; i++) {
            double dx = xs[i] - avgX;
            double dy = ys[i] - avgY;
            sxx += dx * dx;
            sxy += dx * dy;
        }

        double slope = sxy / sxx;

        meanX = avgX;
        meanY = avgY;
        calculatedA = slope;
        /* The fitted line passes through (meanX, meanY) */
        calculatedB = avgY - slope * avgX;
        calculated = true;
    }

    /**
     * Analyses the results.
     *
     * <p>This method sets up rss, the sum of squared differences between each
     * observed y and the value predicted by the fitted line. It triggers
     * {@link #calculate()} first and does nothing if a previous call already
     * completed. For an empty data set rss is 0.
     */
    protected void analyse() {
        if (analysed) {
            return;
        }

        calculate();

        double[] xs = data.getX();
        double[] ys = data.getY();
        int n = xs.length;

        double sumOfSquaredResiduals = 0.0;
        for (int i = 0; i < n; i++) {
            double residual = ys[i] - (calculatedA * xs[i] + calculatedB);
            sumOfSquaredResiduals += residual * residual;
        }

        rss = sumOfSquaredResiduals;
        // Mark as done only after rss holds its final value.
        analysed = true;
    }

    /**
     * Returns the least squares estimate of A, which this class computes as the
     * slope of the fitted line.
     *
     * @return the least squares estimate of A
     */
    public double getA() {
        calculate();
        return calculatedA;
    }

    /**
     * Returns the least squares estimate of B, which this class computes as the
     * intercept of the fitted line.
     *
     * @return the least squares estimate of B
     */
    public double getB() {
        calculate();
        return calculatedB;
    }

    /**
     * Returns the mean of X.
     *
     * @return the mean of X, or {@code NaN} if there is no data
     */
    public double getMeanX() {
        calculate();
        return meanX;
    }

    /**
     * Returns the mean of Y.
     *
     * @return the mean of Y, or {@code NaN} if there is no data
     */
    public double getMeanY() {
        calculate();
        return meanY;
    }

    /**
     * Returns the sum of squared differences between the fitted line and the data.
     *
     * @return the residual sum of squares
     */
    public double getSumOfSquares() {
        analyse();
        return rss;
    }
}