package hex;
import water.*;
import water.api.DocGen;
import water.fvec.*;
import water.util.RString;
/**
 * Simple (one-predictor) ordinary-least-squares linear regression between two columns of a
 * frame: fits {@code y = beta0 + beta1 * x} with {@code vec_x} as the predictor and
 * {@code vec_y} as the response.
 *
 * <p>The fit is computed in three distributed passes over the data (see {@link #serve()}).
 * Every pass ignores rows where either the X or the Y value reads back as NaN
 * (presumably how missing values surface through {@code Chunk.at0} -- confirm).
 */
public class LR2 extends Request2 {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  // This Request supports the HTML 'GET' command, and this is the help text
  // for GET.
  static final String DOC_GET = "Linear Regression between 2 columns";

  @API(help="Data Frame", required=true, filter=Default.class)
  Frame source;

  @API(help="Column X", required=true, filter=LR2VecSelect.class)
  Vec vec_x;

  @API(help="Column Y", required=true, filter=LR2VecSelect.class)
  Vec vec_y;

  // NOTE(review): intentionally a non-static inner class; the request framework may build
  // filters reflectively with the enclosing request instance -- confirm before making static.
  /** Column selector whose choices come from the {@code source} frame argument. */
  class LR2VecSelect extends VecSelect { LR2VecSelect() { super("source"); } }

  // ---- Outputs (filled in by serve) ----
  @API(help="Pass 1 msec") long pass1time;
  @API(help="Pass 2 msec") long pass2time;
  @API(help="Pass 3 msec") long pass3time;
  @API(help="nrows") long nrows;
  @API(help="beta0") double beta0;
  @API(help="beta1") double beta1;
  @API(help="r-squared") double r2;
  @API(help="SSTO") double ssto;
  @API(help="SSE") double sse;
  @API(help="SSR") double ssr;
  @API(help="beta0 Std Error") double beta0stderr;
  @API(help="beta1 Std Error") double beta1stderr;

  /**
   * Runs the regression and fills in all output fields.
   *
   * <ol>
   *   <li>Pass 1: count usable rows and sum X and Y, giving the two means.</li>
   *   <li>Pass 2: centered sums {@code Sxx}, {@code Syy} (= SSTO) and {@code Sxy}; then
   *       {@code beta1 = Sxy / Sxx} and {@code beta0 = meanY - beta1 * meanX}.</li>
   *   <li>Pass 3: residual sum of squares (SSE) and regression sum of squares (SSR) of the
   *       fitted line; {@code r2 = SSR / SSTO}.</li>
   * </ol>
   *
   * <p>Degenerate inputs are reported rather than rejected. With no usable rows, or a constant X
   * column ({@code Sxx == 0}), the coefficients come out NaN/infinite. With fewer than 3 usable
   * rows there are no residual degrees of freedom, so both standard errors are set to NaN.
   *
   * @return a completed response wrapping this request
   */
  @Override public Response serve() {
    // Pass 1: compute sums & sums-of-squares
    long start = System.currentTimeMillis();
    CalcSumsTask lr1 = new CalcSumsTask().doAll(vec_x, vec_y);
    long pass1 = System.currentTimeMillis();
    pass1time = pass1 - start;
    nrows = lr1._n;

    // Pass 2: Compute squared errors around the means
    final double meanX = lr1._sumX/nrows;
    final double meanY = lr1._sumY/nrows;
    CalcSquareErrorsTasks lr2 = new CalcSquareErrorsTasks(meanX, meanY).doAll(vec_x, vec_y);
    long pass2 = System.currentTimeMillis();
    pass2time = pass2 - pass1;
    ssto = lr2._YYbar;

    // Compute the regression (closed-form OLS with a single predictor)
    beta1 = lr2._XYbar / lr2._XXbar;
    beta0 = meanY - beta1 * meanX;

    // Pass 3: residual and regression sums of squares of the fitted line
    CalcRegressionTask lr3 = new CalcRegressionTask(beta0, beta1, meanY).doAll(vec_x, vec_y);
    long pass3 = System.currentTimeMillis();
    pass3time = pass3 - pass2;
    sse = lr3._rss;
    ssr = lr3._ssr;
    r2 = ssr / ssto;

    // Standard errors: s^2 = SSE/(n-2); Var(beta1) = s^2/Sxx; Var(beta0) = s^2/n + meanX^2*Var(beta1).
    // With n <= 2 the residual variance is undefined (dividing by a zero/negative df would
    // give a meaningless Infinity/NaN), so report NaN explicitly.
    long df = nrows - 2;
    if( df > 0 ) {
      double svar = sse / df;
      double svar1 = svar / lr2._XXbar;
      double svar0 = svar/nrows + meanX*meanX*svar1;
      beta0stderr = Math.sqrt(svar0);
      beta1stderr = Math.sqrt(svar1);
    } else {
      beta0stderr = Double.NaN;
      beta1stderr = Double.NaN;
    }
    return Response.done(this);
  }

  /**
   * Pass 1: counts rows where both X and Y are non-NaN and accumulates the sums of X, Y and X^2
   * over those rows. ({@code _sumX2} is accumulated but not currently read by {@link LR2#serve()}.)
   */
  public static class CalcSumsTask extends MRTask2<CalcSumsTask> {
    long _n; // Rows used
    double _sumX,_sumY,_sumX2; // Sum of X's, Y's, X^2's
    @Override public void map( Chunk xs, Chunk ys ) {
      for( int i=0; i<xs._len; i++ ) {
        double X = xs.at0(i);
        double Y = ys.at0(i);
        if( !Double.isNaN(X) && !Double.isNaN(Y)) {
          _sumX += X;
          _sumY += Y;
          _sumX2+= X*X;
          _n++;
        }
      }
    }
    /** Combines partial sums from another chunk's result into this one. */
    @Override public void reduce( CalcSumsTask lr1 ) {
      _sumX += lr1._sumX ;
      _sumY += lr1._sumY ;
      _sumX2+= lr1._sumX2;
      _n += lr1._n;
    }
  }

  /**
   * Pass 2: for rows where both values are non-NaN, accumulates the centered sums
   * {@code sum((x-meanX)^2)}, {@code sum((y-meanY)^2)} and {@code sum((x-meanX)(y-meanY))}.
   */
  public static class CalcSquareErrorsTasks extends MRTask2<CalcSquareErrorsTasks> {
    final double _meanX, _meanY;
    double _XXbar, _YYbar, _XYbar;
    CalcSquareErrorsTasks( double meanX, double meanY ) { _meanX = meanX; _meanY = meanY; }
    @Override public void map( Chunk xs, Chunk ys ) {
      for( int i=0; i<xs._len; i++ ) {
        double Xa = xs.at0(i);
        double Ya = ys.at0(i);
        if(!Double.isNaN(Xa) && !Double.isNaN(Ya)) {
          Xa -= _meanX;
          Ya -= _meanY;
          _XXbar += Xa*Xa;
          _YYbar += Ya*Ya;
          _XYbar += Xa*Ya;
        }
      }
    }
    /** Combines partial centered sums from another chunk's result into this one. */
    @Override public void reduce( CalcSquareErrorsTasks lr2 ) {
      _XXbar += lr2._XXbar;
      _YYbar += lr2._YYbar;
      _XYbar += lr2._XYbar;
    }
  }

  /**
   * Pass 3: for rows where both values are non-NaN, evaluates the fitted line
   * {@code fit = beta1*x + beta0} and accumulates the residual sum of squares
   * {@code sum((fit-y)^2)} and the regression sum of squares {@code sum((fit-meanY)^2)}.
   */
  public static class CalcRegressionTask extends MRTask2<CalcRegressionTask> {
    final double _meanY;
    final double _beta0, _beta1;
    double _rss, _ssr;
    CalcRegressionTask(double beta0, double beta1, double meanY) {_beta0=beta0; _beta1=beta1; _meanY=meanY;}
    @Override public void map( Chunk xs, Chunk ys ) {
      for( int i=0; i<xs._len; i++ ) {
        double X = xs.at0(i); double Y = ys.at0(i);
        if( !Double.isNaN(X) && !Double.isNaN(Y) ) {
          double fit = _beta1*X + _beta0;
          double rs = fit-Y;
          _rss += rs*rs;
          double sr = fit-_meanY;
          _ssr += sr*sr;
        }
      }
    }
    /** Combines partial sums of squares from another chunk's result into this one. */
    @Override public void reduce( CalcRegressionTask lr3 ) {
      _rss += lr3._rss;
      _ssr += lr3._ssr;
    }
  }

  /**
   * Returns an HTML link to this page with the {@code source} frame argument preset.
   *
   * @param k       key of the frame to regress on
   * @param content text of the anchor
   * @return an HTML {@code <a>} element pointing at {@code LR2.query}
   */
  public static String link(Key k, String content) {
    // The query parameter must be named after the @API field ("source"); the former
    // "data_key" did not correspond to any argument of this request.
    RString rs = new RString("<a href='LR2.query?source=%$key'>%content</a>");
    rs.replace("key", k.toString());
    rs.replace("content", content);
    return rs.toString();
  }
}