/*
* Author: tdanford
* Date: Aug 27, 2008
*/
package org.seqcode.ml.regression;
import java.util.*;
import java.io.*;
import java.util.regex.*;
import org.seqcode.gseutils.BitVector;
import org.seqcode.gseutils.Predicate;
import org.seqcode.gseutils.models.*;
import java.lang.reflect.*;
import cern.jet.random.ChiSquare;
import cern.jet.random.Normal;
import cern.jet.random.engine.RandomEngine;
import Jama.*;
/**
 * Least-squares linear regression over the rows of a {@link DataFrame}.
 *
 * <p>The response and predictor columns are named either by a statement of the
 * form {@code "y ~ x1 + x2 + 1"} or by a {@link RegressionModel}. A fit is
 * computed by one of the {@code calculate(...)} overloads through a QR
 * decomposition of the predictor matrix; its results (coefficients, unscaled
 * coefficient covariance, residual variance, R-squared) are then available
 * from the getters, and {@link #sampleVar()}, {@link #sampleBeta(double)} and
 * {@link #sampleBetaBounds(int)} simulate draws around the fitted
 * coefficients (see pg. 356 of Gelman, as cited by the original author).
 *
 * <p>Fit results are NOT recomputed automatically when the frame is filtered
 * or transformed; call {@code calculate(...)} again afterwards.
 *
 * @param <M> the row type held by the data frame
 */
public class DataRegression<M extends Model> {

    /** Number of simulated draws used by {@link #calculateBounds()}. */
    private static final int DEFAULT_BOUND_SAMPLES = 100;

    /** Matches {@code <response> ~ <predictor terms>}. */
    private static final Pattern stmtPattern = Pattern.compile(
            "\\s*([^\\s~]+)\\s*~\\s*(.*)");

    /**
     * Small demonstration: loads a frame of {@code XYPoint} rows, doubles every
     * {@code y}, fits {@code y ~ x + 1} and prints each coefficient with its
     * simulated bounds.
     *
     * @param args optional; {@code args[0]} is the path of the input file. When
     *             absent, the original author's hard-coded path is used.
     */
    public static void main(String[] args) {
        String path = args.length > 0
                ? args[0]
                : "C:\\Documents and Settings\\tdanford\\Desktop\\test.txt";
        File f = new File(path);
        try {
            DataFrame<XYPoint> df = new DataFrame<XYPoint>(XYPoint.class, f);
            DataRegression<XYPoint> reg = new DataRegression<XYPoint>(df, "y ~ x + 1");

            reg.transform(new ATransformation<XYPoint,XYPoint>(XYPoint.class,XYPoint.class) {
                public XYPoint transform(XYPoint v) {
                    v.y *= 2.0;
                    return v;
                }
            });

            Map<String,Double> coeffs = reg.calculateRegression();
            Map<String,Double[]> bounds = reg.calculateBounds();

            for(String title : coeffs.keySet()) {
                Double[] b = bounds.get(title);
                System.out.println(String.format("%s \t%.3f\t(%.3f, %.3f)",
                        title, coeffs.get(title),
                        b[0], b[1]));
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // --- data -------------------------------------------------------------
    private DataFrame<M> frame;
    private Predicted<M> dataY;
    private Predictors<M> dataX;
    private final String dataYVar;
    private final String[] dataXVars;

    // --- results of the most recent fit (betaHat == null means "no fit yet") ---
    private QRDecomposition qr;
    private Matrix Rinv, betaHat, Vbeta;
    private double s2, r2;
    /** Rows minus columns of the X matrix used by the most recent fit. */
    private int residualDf;

    // --- random sources for the sampling methods --------------------------
    private final RandomEngine engine;
    private final Normal ndist;

    /**
     * Creates a regression whose variables are named by a statement such as
     * {@code "y ~ x + 1"}. Terms are whitespace-separated and joined by
     * {@code +} or {@code -}. An intercept term {@code "1"} is appended
     * automatically unless the statement mentions {@code 1} itself;
     * {@code "- 1"} suppresses the intercept.
     *
     * @param f    the data frame supplying the rows
     * @param stmt the regression statement
     * @throws IllegalArgumentException if the statement cannot be parsed
     */
    public DataRegression(DataFrame<M> f, String stmt) {
        frame = f;
        engine = new cern.jet.random.engine.DRand();
        ndist = new Normal(0.0, 1.0, engine);

        List<String> vs = parseStatement(stmt);
        if(vs == null) {
            throw new IllegalArgumentException(String.format("Couldn't parse statement \"%s\"",
                    stmt));
        }

        dataYVar = vs.get(0);
        dataXVars = vs.subList(1, vs.size()).toArray(new String[vs.size()-1]);

        dataY = new Predicted<M>(frame, dataYVar);
        dataX = new Predictors<M>(frame, dataXVars);
    }

    /**
     * Creates a regression whose variables are taken from the fields declared
     * by a {@link RegressionModel}. When the model has an intercept variable,
     * the intercept term {@code "1"} is placed first, followed by the
     * independent variables in the order the model reports them.
     *
     * @param f the data frame supplying the rows
     * @param m the model describing the dependent and independent variables
     */
    public DataRegression(DataFrame<M> f, RegressionModel m) {
        frame = f;
        engine = new cern.jet.random.engine.DRand();
        ndist = new Normal(0.0, 1.0, engine);

        Field dvar = m.getDependentVariable();
        Vector<Field> ivars = m.getIndependentVariables();
        boolean intercept = m.hasInterceptVariable();

        List<String> names = new ArrayList<String>(ivars.size() + 1);
        if(intercept) {
            names.add("1");
        }
        for(Field ivar : ivars) {
            names.add(ivar.getName());
        }

        dataYVar = dvar.getName();
        dataXVars = names.toArray(new String[names.size()]);

        dataY = new Predicted<M>(frame, dataYVar);
        dataX = new Predictors<M>(frame, dataXVars);
    }

    /**
     * Applies a predicate filter to the underlying frame by delegating to
     * {@code frame.filter(p)}.
     *
     * <p>NOTE(review): unlike {@link #transform(Transformation)}, nothing is
     * reassigned and the Predicted/Predictors views are not rebuilt here. This
     * presumably relies on {@code DataFrame.filter} working in place; confirm
     * against DataFrame.
     *
     * @param p the predicate handed to the frame
     */
    public void filter(Predicate<M> p) {
        frame.filter(p);
    }

    /**
     * Replaces the frame with its transformed version and rebuilds the
     * response and predictor views on top of it. Any previously computed fit
     * is left as-is and therefore describes the old data until
     * {@code calculate(...)} is called again.
     *
     * @param t the row transformation
     */
    public void transform(Transformation<M,M> t) {
        frame = frame.transform(t);
        dataY = new Predicted<M>(frame, dataYVar);
        dataX = new Predictors<M>(frame, dataXVars);
    }

    /** @return the predictor column names as reported by the predictor view */
    public Vector<String> getPredictorNames() {
        return dataX.getColumnNames();
    }

    /** @return the predictor view over the current frame */
    public Predictors<M> getPredictors() { return dataX; }

    /** @return the response view over the current frame */
    public Predicted<M> getPredicted() { return dataY; }

    /** @return a freshly created predictor matrix */
    public Matrix getPredictorMatrix() {
        return dataX.createMatrix();
    }

    /** @return a freshly created response vector */
    public Matrix getPredictedVector() {
        return dataY.createVector();
    }

    /** @return the current data frame */
    public DataFrame<M> getFrame() { return frame; }

    /**
     * Fits the regression on the current data and returns the coefficients.
     *
     * @return coefficient values keyed by predictor column name, in column order
     */
    public Map<String,Double> calculateRegression() {
        calculate();
        return collectCoefficients();
    }

    /**
     * Returns the coefficients of the most recent fit.
     *
     * @return coefficient values keyed by predictor column name, in column order
     * @throws IllegalStateException if no fit has been calculated yet
     */
    public Map<String,Double> collectCoefficients() {
        requireFit();
        Map<String,Double> map = new LinkedHashMap<String,Double>();
        for(int i = 0; i < betaHat.getRowDimension(); i++) {
            map.put(dataX.getColumnName(i), betaHat.get(i, 0));
        }
        return map;
    }

    /**
     * Simulates coefficient bounds from the most recent fit using 100 draws
     * (see {@link #sampleBetaBounds(int)}).
     *
     * @return for each predictor column name, a two-element array
     *         {@code {lower, upper}}
     * @throws IllegalStateException if no fit has been calculated yet
     */
    public Map<String,Double[]> calculateBounds() {
        requireFit();
        Map<String,Double[]> map = new LinkedHashMap<String,Double[]>();
        Vector<Double[]> bounds = sampleBetaBounds(DEFAULT_BOUND_SAMPLES);
        for(int i = 0; i < betaHat.getRowDimension(); i++) {
            map.put(dataX.getColumnName(i), bounds.get(i));
        }
        return map;
    }

    /** Fits the regression with no selector and no per-column transforms. */
    public void calculate() {
        calculate(null);
    }

    /**
     * Fits the regression, passing the selector through to the response and
     * predictor views.
     *
     * @param selector passed to {@code createVector}/{@code createMatrix};
     *                 may be {@code null}
     */
    public void calculate(BitVector selector) {
        calculate(selector, null);
    }

    /**
     * Fits the regression on the matrices the views build for the given
     * selector and per-column transforms.
     *
     * @param selector   passed to {@code createVector}/{@code createMatrix};
     *                   may be {@code null}
     * @param transforms passed to {@code createMatrix}; may be {@code null}
     */
    public void calculate(BitVector selector, Map<String,Transformation<Double,Double>> transforms) {
        Matrix y = dataY.createVector(selector);
        Matrix X = dataX.createMatrix(selector, transforms);
        calculate(X, y);
    }

    /**
     * Solves the least-squares problem for {@code y ~ X} through a QR
     * decomposition, without touching any instance state.
     *
     * @param X the predictor matrix
     * @param y the response column vector
     * @return the coefficient column vector
     */
    public static Matrix leastSquares(Matrix X, Matrix y) {
        QRDecomposition qr = new QRDecomposition(X);
        Matrix R = qr.getR();
        Matrix Qtransy = qr.getQ().transpose().times(y);
        return R.solve(Qtransy);
    }

    /**
     * Residual variance estimate: the residual sum of squares of
     * {@code y - X*betaHat} divided by (rows of X - columns of X).
     * The division is not guarded against rows == columns.
     *
     * @param X       the predictor matrix
     * @param y       the response column vector
     * @param betaHat the coefficient column vector
     * @return the residual variance estimate
     */
    public static double s2(Matrix X, Matrix y, Matrix betaHat) {
        return residualVariance(X, y, X.times(betaHat));
    }

    /**
     * Fits the regression on explicit matrices and stores the results
     * (QR decomposition, inverse of R, unscaled coefficient covariance,
     * coefficients, residual variance and R-squared) in this object.
     *
     * <p>All results are computed first and committed together, so a failure
     * part-way through (for example a singular R) leaves the previous fit
     * untouched.
     *
     * @param X the predictor matrix
     * @param y the response column vector
     */
    public void calculate(Matrix X, Matrix y) {
        QRDecomposition decomposition = new QRDecomposition(X);
        Matrix R = decomposition.getR();
        Matrix rInverse = R.inverse();
        Matrix unscaledCovariance = rInverse.times(rInverse.transpose());

        Matrix Qtransy = decomposition.getQ().transpose().times(y);
        Matrix coefficients = R.solve(Qtransy);

        Matrix yhat = X.times(coefficients);
        double variance = residualVariance(X, y, yhat);
        double rSquared = rSquared(y, yhat);

        // Commit the complete fit in one go.
        qr = decomposition;
        Rinv = rInverse;
        Vbeta = unscaledCovariance;
        betaHat = coefficients;
        s2 = variance;
        r2 = rSquared;
        residualDf = X.getRowDimension() - X.getColumnDimension();
    }

    /** Residual sum of squares of {@code y - yhat} over (rows of X - columns of X). */
    private static double residualVariance(Matrix X, Matrix y, Matrix yhat) {
        Matrix errors = y.minus(yhat);
        int n = X.getRowDimension(), k = X.getColumnDimension();
        double sumOfSquares = (errors.transpose().times(errors)).get(0, 0);
        return sumOfSquares / (double)(n - k);
    }

    /**
     * Coefficient of determination, {@code 1 - SSE/SST}. Not guarded against
     * a constant response (SST of zero).
     */
    private static double rSquared(Matrix y, Matrix yhat) {
        int rows = y.getRowDimension();

        double mean = 0.0;
        for(int i = 0; i < rows; i++) {
            mean += y.get(i, 0);
        }
        mean /= (double)rows;

        double SSE = 0.0, SST = 0.0;
        for(int i = 0; i < rows; i++) {
            double yvalue = y.get(i, 0);
            double sstDiff = yvalue - mean;
            double sseDiff = yvalue - yhat.get(i, 0);
            SST += (sstDiff * sstDiff);
            SSE += (sseDiff * sseDiff);
        }

        return 1.0 - (SSE / SST);
    }

    /** @return the coefficients of the most recent fit, or {@code null} before any fit */
    public Matrix getBetaHat() { return betaHat; }

    /** @return {@code Rinv * Rinv^T} from the most recent fit, or {@code null} before any fit */
    public Matrix getVarBeta() { return Vbeta; }

    /** @return R-squared of the most recent fit (0.0 before any fit) */
    public double getR2() { return r2; }

    /** @return residual variance of the most recent fit (0.0 before any fit) */
    public double getS2() { return s2; }

    /** @return the size reported by the current predictor view */
    public int getN() { return dataX.size(); }

    /** @return the number of columns reported by the current predictor view */
    public int getK() { return dataX.getNumColumns(); }

    /**
     * Draws {@code iters} coefficient vectors (each from a fresh variance
     * draw) and, per coefficient, returns the sorted draws found at positions
     * {@code iters/4} and {@code 3*(iters/4)}.
     *
     * @param iters number of draws; must be at least 1
     * @return one {@code {lower, upper}} pair per coefficient, in coefficient order
     * @throws IllegalArgumentException if {@code iters} is less than 1
     * @throws IllegalStateException    if no fit has been calculated yet
     */
    public Vector<Double[]> sampleBetaBounds(int iters) {
        if(iters < 1) {
            throw new IllegalArgumentException(
                    String.format("iters must be at least 1, got %d", iters));
        }
        requireFit();

        // Use the dimension of the fitted coefficients, which is what
        // sampleBeta() produces, rather than the current predictor view.
        int k = betaHat.getRowDimension();
        double[][] draws = new double[k][iters];

        for(int i = 0; i < iters; i++) {
            double var = sampleVar();
            Matrix beta = sampleBeta(var);
            for(int j = 0; j < k; j++) {
                draws[j][i] = beta.get(j, 0);
            }
        }

        int lower = (iters/4);
        int upper = 3*(iters/4);

        Vector<Double[]> bounds = new Vector<Double[]>();
        for(int j = 0; j < k; j++) {
            double[] sorted = draws[j];
            Arrays.sort(sorted);
            bounds.add(new Double[] { sorted[lower], sorted[upper] });
        }

        return bounds;
    }

    /**
     * Draws one coefficient vector: {@code betaHat + sqrt(var) * Rinv * z},
     * where {@code z} is a vector of independent standard normal draws.
     *
     * @param var the variance to scale the draw with
     * @return the sampled coefficient column vector
     * @throws IllegalStateException if no fit has been calculated yet
     */
    public Matrix sampleBeta(double var) {
        requireFit();

        Matrix beta = new Matrix(betaHat.getRowDimension(), 1);
        for(int i = 0; i < beta.getRowDimension(); i++) {
            beta.set(i, 0, ndist.nextDouble());
        }

        double sd = Math.sqrt(var);
        return Rinv.times(sd).times(beta).plus(betaHat);
    }

    /**
     * Draws one variance value: {@code df * s2 / x}, where {@code x} is a
     * chi-square draw with {@code df} degrees of freedom and {@code df} is
     * the rows-minus-columns count of the X matrix used by the most recent fit.
     *
     * @return the sampled variance
     * @throws IllegalStateException if no fit has been calculated yet
     */
    public double sampleVar() {
        requireFit();
        double diff = (double)residualDf;
        ChiSquare chiSquare = new cern.jet.random.ChiSquare(diff, engine);
        double x = chiSquare.nextDouble();
        return (diff * s2) / x;
    }

    /**
     * Residual variance computed directly from the full frame.
     *
     * @deprecated use {@link #calculate()} followed by {@link #getS2()}, or the
     *             static {@link #s2(Matrix, Matrix, Matrix)}
     * @param betaHat coefficient column vector; when {@code null} it is
     *                computed with {@link #calculateBetaHat()}
     * @return the residual sum of squares scaled by {@code 1 / (n - k)}
     */
    @Deprecated
    public double calculateS2(Matrix betaHat) {
        // pg. 356 of Gelman
        // n : number of datapoints
        // k : number of predictors
        double n = (double)frame.size();
        double k = (double)dataX.getNumColumns();
        double coeff = 1.0 / (n - k);

        Matrix y = dataY.createVector();
        Matrix X = dataX.createMatrix();

        if(betaHat == null) { betaHat = calculateBetaHat(); }

        Matrix half = y.minus(X.times(betaHat));
        Matrix product = half.transpose().times(half);

        return coeff * product.get(0, 0);
    }

    /**
     * Coefficients computed through the normal equations,
     * {@code (X^T X)^-1 X^T y}, on the full frame.
     *
     * @deprecated use {@link #calculate()} followed by {@link #getBetaHat()}, or
     *             the static {@link #leastSquares(Matrix, Matrix)}
     * @return the coefficient column vector (k x 1)
     */
    @Deprecated
    public Matrix calculateBetaHat() {
        // pg. 356 of Gelman
        // n : number of datapoints
        // k : number of predictors

        // Xtrans : k x n
        Matrix Xtrans = dataX.createMatrix().transpose();

        // xtxInverse : k x k
        Matrix xtxInverse = Xtrans.times(Xtrans.transpose()).inverse();

        // ytransf : k x 1
        Matrix ytransf = Xtrans.times(dataY.createVector());

        // result : k x 1
        return xtxInverse.times(ytransf);
    }

    /** Fails fast when a method needing fit results is called before any fit. */
    private void requireFit() {
        if(betaHat == null || Rinv == null) {
            throw new IllegalStateException(
                    "No regression has been calculated yet; call calculate() first");
        }
    }

    /**
     * Parses a statement of the form {@code "y ~ a + b - 1"}.
     *
     * <p>The right-hand side is split on whitespace; tokens at even positions
     * are terms and tokens at odd positions must be {@code +} or {@code -}.
     * The term {@code 1} is the intercept: it is kept when added, dropped when
     * preceded by {@code -}, and appended at the end when the statement never
     * mentions it. A {@code -} in front of any other term has no effect; the
     * term is still included.
     *
     * @param stmt the statement text
     * @return the response name followed by the predictor names, or
     *         {@code null} if the statement is malformed (no {@code ~}, empty
     *         right-hand side, an operator where a term is expected, a term
     *         where an operator is expected, or a dangling trailing operator)
     */
    private static List<String> parseStatement(String stmt) {
        Matcher m = stmtPattern.matcher(stmt);
        if(!m.matches()) {
            return null;
        }

        String preds = m.group(2);
        if(preds.length() == 0) {
            // "y ~" would otherwise yield an empty-string predictor.
            return null;
        }

        String[] tokens = preds.split("\\s+");
        if(tokens.length % 2 == 0) {
            // An even token count means the statement ends with an operator.
            return null;
        }

        List<String> v = new ArrayList<String>();
        v.add(m.group(1));

        boolean seenConstant = false;
        boolean lastMinus = false;

        for(int i = 0; i < tokens.length; i++) {
            String token = tokens[i];
            boolean isOperator = token.equals("+") || token.equals("-");

            if(i % 2 == 1) {
                if(!isOperator) {
                    return null;
                }
                lastMinus = token.equals("-");
            } else if(isOperator) {
                return null;
            } else if(token.equals("1")) {
                seenConstant = true;
                if(!lastMinus) {
                    v.add(token);
                }
            } else {
                v.add(token);
            }
        }

        // No explicit mention of the intercept at all: include it by default.
        if(!seenConstant) {
            v.add("1");
        }

        return v;
    }

    /**
     * Prints a matrix as a table with a header row of column indices and a
     * leading row index on every line.
     *
     * @param m         the matrix to print
     * @param ps        the destination stream
     * @param precision number of digits after the decimal point for each entry
     */
    public static void printMatrix(Matrix m, PrintStream ps, int precision) {
        String format = "%." + precision + "f";

        ps.print(" \t");
        for(int j = 0; j < m.getColumnDimension(); j++) {
            if(j > 0) { ps.print(" "); }
            ps.print(String.format(" %3d", j));
        }
        ps.println();

        for(int i = 0; i < m.getRowDimension(); i++) {
            ps.print(String.format("%3d\t", i));
            for(int j = 0; j < m.getColumnDimension(); j++) {
                if(j > 0) { ps.print(" "); }
                ps.print(String.format(format, m.get(i, j)));
            }
            ps.println();
        }
    }
}