/*
* Author: tdanford
* Date: Aug 27, 2008
*/
package org.seqcode.ml.regression;
import java.util.*;
import org.seqcode.gseutils.BitVector;
import org.seqcode.gseutils.models.*;
import java.lang.reflect.*;
import Jama.*;
/**
 * Extracts the values of one public numeric field from every row object of a
 * {@link DataFrame} and packs them into a single-column JAMA {@link Matrix}.
 *
 * <p>The field is looked up reflectively, by name, on the frame's model class
 * when the instance is constructed.
 *
 * @param <M> the model type stored in the underlying data frame
 */
public class Predicted<M extends Model> {

    private final DataFrame<M> frame;
    private final Field numericField;

    /**
     * Binds this object to the field named {@code pfn} of the frame's model class.
     *
     * @param f   the data frame whose rows will be read
     * @param pfn the name of a public field on the frame's model class
     * @throws IllegalArgumentException if the model class has no public field
     *         called {@code pfn}, or if {@code Model.isSubclass} reports that the
     *         field's type is not a {@link Number} subtype
     */
    public Predicted(DataFrame<M> f, String pfn) {
        frame = f;
        Class<M> mcls = frame.getModelClass();

        Field mf;
        try {
            mf = mcls.getField(pfn);
        } catch (NoSuchFieldException e) {
            // Keep the original exception as the cause instead of printing it.
            throw new IllegalArgumentException(String.format(
                    "Unknown field name: %s", pfn), e);
        }

        Class<?> t = mf.getType();
        if (!Model.isSubclass(t, Number.class)) {
            // The declared type is included so the caller can see what was rejected.
            throw new IllegalArgumentException(String.format(
                    "Field %s is not numeric: %s", pfn, t.getName()));
        }
        numericField = mf;
    }

    /**
     * Builds a column vector holding the field value of each selected row.
     *
     * <p>Rows are visited in frame order. When {@code selector} is {@code null}
     * every row is used; otherwise only the rows for which
     * {@code selector.isOn(rowIndex)} is true are used, and the result has
     * {@code selector.countOnBits()} rows.
     * NOTE(review): this assumes the selector's on-bits all lie within
     * {@code [0, frame.size())} -- confirm against callers.
     *
     * @param selector row filter, or {@code null} to take all rows
     * @return a matrix with one column and one row per selected frame row
     * @throws IllegalStateException if the field cannot be read reflectively, or
     *         if a selected row holds {@code null} in the field
     */
    public Matrix createVector(BitVector selector) {
        int total = frame.size();
        int rows = selector != null ? selector.countOnBits() : total;
        Matrix m = new Matrix(rows, 1);

        int out = 0; // next free row of the output vector
        for (int j = 0; j < total; j++) {
            if (selector != null && !selector.isOn(j)) {
                continue;
            }
            m.set(out, 0, readValue(j));
            out++;
        }
        return m;
    }

    /**
     * Builds a column vector holding the field value of every row of the frame.
     *
     * @return a matrix with one column and {@code frame.size()} rows
     * @throws IllegalStateException if the field cannot be read reflectively, or
     *         if a row holds {@code null} in the field
     */
    public Matrix createVector() {
        // A null selector means "all rows", so the shared implementation is reused.
        return createVector((BitVector) null);
    }

    /** Reads the bound field from the frame row at {@code row} as a double. */
    private double readValue(int row) {
        M rowValue = frame.object(row);
        Number numValue;
        try {
            numValue = (Number) numericField.get(rowValue);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(String.format("Couldn't access field %s: %s",
                    numericField.getName(), e.getMessage()), e);
        }
        if (numValue == null) {
            // Fail with context rather than a bare NullPointerException.
            throw new IllegalStateException(String.format(
                    "Field %s is null in row %d", numericField.getName(), row));
        }
        return numValue.doubleValue();
    }
}