/*
* Author: tdanford
* Date: Aug 27, 2008
*/
package org.seqcode.ml.regression;
import java.util.*;
import java.util.regex.*;
import org.seqcode.gseutils.BitVector;
import org.seqcode.gseutils.models.*;
import java.lang.reflect.*;
import Jama.*;
public class Predictors<M extends Model> {
private static Pattern interactionPattern = Pattern.compile("([^:]+):(.+)");
private DataFrame<M> frame;
private boolean hasConstant;
private Vector<Field> numeric; // the list of 'numeric' variables.
private Vector<Field> factor; // the list of factor variables.
private Map<Field,Vector<Object>> factorCodes; // the unique values that can be taken by each factor variable.
// Each interaction is the product of one or more variables, some of which may be factors and the others 'numeric'.
private Vector<Interaction> interactions;
// Each InteractionValue is a product of values from
// the variables in an interaction which are 'factor' variables.
private Map<Interaction,Vector<InteractionValue>> interactionValues;
private int cols;
private Vector<String> columnNames;
public Predictors(DataFrame<M> f, String... fs) {
frame = f;
numeric = new Vector<Field>();
factor = new Vector<Field>();
interactions = new Vector<Interaction>();
hasConstant = false;
factorCodes = new HashMap<Field,Vector<Object>>();
interactionValues = new HashMap<Interaction,Vector<InteractionValue>>();
columnNames = new Vector<String>();
cols = 0;
Class<M> mcls = frame.getModelClass();
Set<String> seenFields = new HashSet<String>();
ModelFieldAnalysis<M> analysis = new ModelFieldAnalysis<M>(f.getModelClass());
for(int i = 0; i < fs.length; i++) {
if(seenFields.contains(fs[i])) {
throw new IllegalArgumentException(String.format(
"Duplicate field name: %s", fs[i]));
}
if(fs[i].equals("1")) {
hasConstant = true;
cols += 1;
} else {
try {
Matcher inMatcher = interactionPattern.matcher(fs[i]);
Field mf = analysis.findField(fs[i]);
if(mf != null) {
Class t = mf.getType();
if(Model.isSubclass(t, Number.class)) {
addPredictor(fs[i]);
} else if (Model.isSubclass(t, String.class)) {
addFactor(fs[i]);
} else {
throw new IllegalArgumentException(String.format(
"Field %s is not a regression-ready predictor", fs[i]));
}
} else if (inMatcher.matches()) {
String[] array = fs[i].split(":");
addInteraction(array);
} else {
throw new IllegalArgumentException(String.format(
"Unknown field name: %s", fs[i]));
}
} catch(NoSuchFieldException e) {
throw new IllegalArgumentException(String.format(
"Unknown field name: %s", fs[i]));
}
}
seenFields.add(fs[i]);
}
if(hasConstant) {
columnNames.insertElementAt("(Intercept)", 0);
}
}
public int size() {
return frame.size();
}
public void addConstant() {
if(!hasConstant) {
hasConstant = true;
columnNames.insertElementAt("(Intercept)", 0);
}
}
public void addInteraction(String... fns) throws NoSuchFieldException {
ModelFieldAnalysis<M> mfa = new ModelFieldAnalysis<M>(frame.getModelClass());
Field[] fields = mfa.findFields(fns);
for(int i = 0; i < fields.length; i++) {
if(fields[i] == null) { throw new NoSuchFieldException(fns[i]); }
}
Interaction inter = new Interaction(fields);
if(interactions.contains(inter)) {
throw new IllegalArgumentException(
String.format("Cannot add the same interaction %s twice.", inter.toString()));
}
interactions.add(inter);
String[] factorFields = inter.findFactorFieldNames();
Vector<InteractionValue> values = inter.allInteractionValues(factorFields);
interactionValues.put(inter, values);
cols += values.size();
for(InteractionValue value : values) {
String colName = String.format("%s(%s)", inter.toString(), value.toString());
columnNames.add(colName);
}
}
public void addPredictor(String fn) throws NoSuchFieldException {
Field f = frame.getModelClass().getField(fn);
Class t = f.getType();
if(Model.isSubclass(t, Double.class)) {
numeric.add(f);
} else if (Model.isSubclass(t, Integer.class)) {
numeric.add(f);
} else {
throw new NoSuchFieldException(String.format(
"%s is not a numeric field (%s)", fn, t.getName()));
}
cols += 1;
columnNames.add(fn);
}
public Set<String> findFactorValues(String fn) {
TreeSet<String> values = new TreeSet<String>();
for(Object v : frame.fieldValues(fn)) {
values.add((String)v);
}
values.remove(values.first());
return values;
}
public void addFactor(String fn) throws NoSuchFieldException {
Field f = frame.getModelClass().getField(fn);
Class t = f.getType();
if(Model.isSubclass(t, String.class)) {
Set<String> values = findFactorValues(fn);
factor.add(f);
//factorCodes.put(f, new Vector(new TreeSet<String>(values)));
factorCodes.put(f, new Vector(values));
cols += values.size();
for(String obj : values) {
columnNames.add(String.format("%s(%s)", fn, obj));
}
} else {
throw new IllegalArgumentException(String.format("%s is not a valid factor-field.", fn));
}
}
public int getNumColumns() { return cols; }
public String getColumnName(int i) { return columnNames.get(i); }
public Matrix createMatrix() {
return createMatrix(null);
}
public Matrix createMatrix(BitVector selector) {
return createMatrix(selector, null);
}
public Matrix createMatrix(BitVector selector, Map<String,Transformation<Double,Double>> transforms) {
int rows = selector != null ? selector.countOnBits() : frame.size();
Matrix m = new Matrix(rows, cols);
int cidx = 0;
if(hasConstant) {
for(int i = 0, j = 0; j < frame.size(); j++) {
if(selector == null || selector.isOn(j)) {
m.set(i, cidx, 1.0);
i += 1;
}
}
cidx += 1;
}
for(Field f : numeric) {
Transformation<Double,Double> transform =
transforms != null && transforms.containsKey(f.getName()) ?
transforms.get(f.getName()) : null;
for(int i = 0, j = 0; j < frame.size(); j++) {
if(selector == null || selector.isOn(j)) {
M rowValue = frame.object(j);
try {
Double numValue = ((Number)f.get(rowValue)).doubleValue();
if(transform != null) {
numValue = transform.transform(numValue);
}
m.set(i, cidx, numValue);
} catch (IllegalAccessException e) {
e.printStackTrace();
throw new IllegalStateException(String.format("Couldn't access field %s: %s",
f.getName(), e.getMessage()));
}
i+=1;
}
}
cidx += 1;
}
for(Field f : factor) {
Vector<Object> factorValues = factorCodes.get(f);
for(int i = 0, j = 0; j < frame.size(); j++) {
if(selector==null || selector.isOn(j)) {
M rowValue = frame.object(j);
try {
Object factorValue = f.get(rowValue);
int idx = factorValues.indexOf(factorValue);
if(idx != -1) {
m.set(i, cidx+idx, 1.0);
}
} catch (IllegalAccessException e) {
e.printStackTrace();
throw new IllegalStateException(String.format("Couldn't access field %s: %s",
f.getName(), e.getMessage()));
}
i += 1;
}
}
cidx += factorValues.size();
}
for(Interaction in : interactions) {
Vector<InteractionValue> values = interactionValues.get(in);
for(int i = 0, j = 0; j < frame.size(); j++) {
if(selector==null || selector.isOn(j)) {
M rowValue = frame.object(j);
InteractionValue inValue = in.calculateInteractionValue(rowValue);
Double numValue = in.calculatePredictor(rowValue);
int idx = values.indexOf(inValue);
if(idx != -1) {
m.set(i, cidx+idx, numValue);
}
i+=1;
}
}
cidx += values.size();
}
return m;
}
private class Interaction {
public Set<Field> fields;
public Interaction(Field... fs) {
fields = new HashSet<Field>();
for(int i = 0; i < fs.length; i++) {
fields.add(fs[i]);
}
}
public int hashCode() {
int code = 17;
for(Field f : fields) {
code += f.hashCode();
}
code *= 37;
return code;
}
public String toString() {
StringBuilder sb = new StringBuilder();
for(Field f : fields) {
if(isFactorField(f)) {
if(sb.length() > 0) { sb.append(":"); }
sb.append(f.getName());
}
}
for(Field f : fields) {
if(!isFactorField(f)) {
if(sb.length() > 0) { sb.append(":"); }
sb.append(f.getName());
}
}
return sb.toString();
}
public boolean equals(Object o) {
if(!(o instanceof Predictors.Interaction)) { return false; }
Interaction in = (Interaction)o;
if(fields.size() != in.fields.size()) { return false; }
for(Field f : fields) {
if(!(in.fields.contains(f))) {
return false;
}
}
return true;
}
private boolean isNumericField(Field f) {
return Model.isSubclass(f.getType(), Number.class);
}
private boolean isFactorField(Field f) {
return Model.isSubclass(f.getType(), String.class);
}
public InteractionValue calculateInteractionValue(Object o) {
InteractionValue value = new InteractionValue();
for(Field f : fields) {
if(isFactorField(f)) {
try {
String v = (String) f.get(o);
value.values.add(v);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
return value;
}
public Double calculatePredictor(Object o) {
Double value = 1.0;
for(Field f : fields) {
if(isNumericField(f)) {
try {
Number n = (Number)f.get(o);
value *= n.doubleValue();
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
return value;
}
public String[] findFactorFieldNames() {
Vector<String> fs = new Vector<String>();
for(Field f : fields) {
if(isFactorField(f)) {
fs.add(f.getName());
}
}
return fs.toArray(new String[fs.size()]);
}
public String[] findNumericFieldNames() {
Vector<String> fs = new Vector<String>();
for(Field f : fields) {
if(isNumericField(f)) {
fs.add(f.getName());
}
}
return fs.toArray(new String[fs.size()]);
}
public Vector<InteractionValue> allInteractionValues(String[] names) {
Vector<InteractionValue> vv = new Vector<InteractionValue>();
vv.add(new InteractionValue());
String[] ff = findFactorFieldNames();
for(int i = 0; i < ff.length; i++) {
//Set values = findFactorValues(ff[i]);
Set values = frame.fieldValues(ff[i]);
vv = appendValues(vv, values);
}
// Analogous to taking out the very first value of a set of factor values.
vv.remove(0);
return vv;
}
private Vector<InteractionValue> appendValues(Vector<InteractionValue> prev, Set vals) {
Vector<InteractionValue> newv = new Vector<InteractionValue>();
for(InteractionValue v : prev) {
newv.addAll(v.extend(vals));
}
return newv;
}
}
private class InteractionValue {
public Vector values;
public InteractionValue() {
values = new Vector();
}
public InteractionValue(InteractionValue v ) {
values = new Vector(v.values);
}
public InteractionValue(InteractionValue v, Object o) {
values = new Vector(v.values);
values.add(o);
}
public Vector<InteractionValue> extend(Set vals) {
Vector<InteractionValue> ivs = new Vector<InteractionValue>();
for(Object v : vals) {
ivs.add(new InteractionValue(this, v));
}
return ivs;
}
public String toString() {
StringBuilder sb = new StringBuilder();
for(int i = 0; i < values.size(); i++) {
if(i > 0) { sb.append("_"); }
sb.append(values.get(i).toString());
}
return sb.toString();
}
public int hashCode() {
int code = 17;
for(Object v : values) { code += v.hashCode(); code *= 37; }
return code;
}
public boolean equals(Object o) {
if(!(o instanceof Predictors.InteractionValue)) { return false; }
InteractionValue iv = (InteractionValue)o;
if(iv.values.size() != values.size()) { return false; }
for(int i = 0; i < values.size(); i++) {
if(!values.get(i).equals(iv.values.get(i))) { return false; }
}
return true;
}
}
public Vector<String> getColumnNames() {
return columnNames;
}
}