/*
* Author: tdanford
* Date: Aug 19, 2008
*/
package org.seqcode.ml.regression;
import java.util.*;
import java.util.regex.*;
import org.seqcode.gseutils.Accumulator;
import org.seqcode.gseutils.Function;
import org.seqcode.gseutils.PackedBitVector;
import org.seqcode.gseutils.Predicate;
import org.seqcode.gseutils.models.Model;
import org.seqcode.gseutils.models.ModelFieldAnalysis;
import org.seqcode.gseutils.models.ModelInput;
import org.seqcode.gseutils.models.ModelInputIterator;
import org.seqcode.gseutils.models.ModelOutput;
import org.seqcode.gseutils.models.ModelInput.LineReader;
import org.seqcode.gseutils.models.ModelOutput.LineWriter;
import java.lang.reflect.*;
import java.io.*;
/**
* DataFrame holds a collection of Model objects -- intuitively, a DataFrame is a matrix
* of data, although the matrix data is not necessarily numerical. Each internal Model is a
* row in the matrix, and each field in the Model is a column. The DataFrame gives accessor
* methods for getting data by either row or column, for reading and writing tables to files,
* for assembling matrices of numerical data, etc.
*
* @author tdanford
* @param <T> The particular subclass of Model whose objects make up the rows of the frame.
*/
public class DataFrame<T extends Model> {
private Class<T> cls;
private ArrayList<T> objects;
private File file;
private ModelFieldAnalysis<T> fieldAnalysis;
/**
* Reads the frame data from a particular file.
*
* @param cls
* @param file
* @throws IOException
*/
public DataFrame(Class<T> cls, File file) throws IOException {
this.cls = cls;
this.file = file;
fieldAnalysis = new ModelFieldAnalysis<T>(cls);
objects = parse(file, true);
}
/**
* Reads the frame data from a particular file.
*
* @param cls
* @param file
* @throws IOException
*/
public DataFrame(Class<T> cls, File file, String... cols) throws IOException {
this.cls = cls;
this.file = file;
fieldAnalysis = new ModelFieldAnalysis<T>(cls);
objects = parse(file, true, cols);
}
public DataFrame(Class<T> cls, File file, boolean header, String... cols) throws IOException {
this.cls = cls;
this.file = file;
fieldAnalysis = new ModelFieldAnalysis<T>(cls);
objects = parse(file, header, cols);
}
/**
* Creates a DataFrame object from the given iterator over values.
*
* @param cls
* @param objs
*/
public DataFrame(Class<T> cls, Iterator<T> objs) {
this.cls = cls;
objects = new ArrayList<T>();
while(objs.hasNext()) {
objects.add(objs.next());
}
this.file = null;
fieldAnalysis = new ModelFieldAnalysis<T>(cls);
}
/**
* Creates a DataFrame object from the given collection of values.
*
* @param cls
* @param objs
*/
public DataFrame(Class<T> cls, Collection<T> objs) {
this.cls = cls;
objects = new ArrayList<T>(objs);
this.file = null;
fieldAnalysis = new ModelFieldAnalysis<T>(cls);
}
/**
* Creates an empty DataFrame object.
*
* @param cls
*/
public DataFrame(Class<T> cls) {
this.cls = cls;
objects = new ArrayList<T>();
file = null;
fieldAnalysis = new ModelFieldAnalysis<T>(cls);
}
public void loadJSON(InputStream is) {
ModelInput<T> input = new ModelInput.LineReader<T>(cls, is);
ModelInputIterator<T> itr = new ModelInputIterator<T>(input);
addObjects(itr);
}
public void saveJSON(OutputStream os) {
ModelOutput<T> output = new ModelOutput.LineWriter<T>(os);
for(T val : objects) { output.writeModel(val); }
}
public Iterator<T> iterator() { return objects.iterator(); }
public PackedBitVector getMask(Predicate<T> pred) {
PackedBitVector pbv = new PackedBitVector(objects.size());
for(int i = 0; i < objects.size(); i++) {
if(pred.accepts(objects.get(i))) {
pbv.turnOnBit(i);
}
}
return pbv;
}
/**
* Removes any item that satisfies the given predicate from this DataFrame.
* Extracted items are collated in a separate DataFrame<T>, which is then
* returned from this method.
*
* @param pred The indicator for which elements are to be removed.
* @return A DataFrame of the removed elements.
*/
public DataFrame<T> extract(Predicate<T> pred) {
DataFrame<T> extracted = new DataFrame<T>(cls);
Iterator<T> itr = objects.iterator();
while(itr.hasNext()) {
T value = itr.next();
if(pred.accepts(value)) {
itr.remove();
extracted.addObject(value);
}
}
return extracted;
}
/**
* A Transformation<T,S> object turns objects of type T into objects of type S.
* The transform() method takes a transformation, where T is the model-type of
* this DataFrame, and returns a new DataFrame of S objects, where each object
* corresponds to a transformed version of the original T from this frame.
*
* @param <S>
* @param trans
* @return
*/
public <S extends Model> DataFrame<S> transform(Transformation<T,S> trans) {
DataFrame<S> df = new DataFrame<S>(trans.toClass());
for(T val : objects) {
df.addObject(trans.transform(val));
}
return df;
}
public DataFrame<T> extend(DataFrame<T> df) {
objects.addAll(df.objects);
return this;
}
public <R extends Model, S extends Model>
DataFrame<R> join(Class<R> rClass,
DataFrame<S> outerFrame,
String fieldName,
String innerName,
String outerName) {
ModelFieldAnalysis<S> analysisS = outerFrame.fieldAnalysis;
Field fieldS = analysisS.findField(fieldName);
ModelFieldAnalysis<T> analysisT = fieldAnalysis;
Field fieldT = analysisT.findField(fieldName);
ModelFieldAnalysis<R> analysisR = new ModelFieldAnalysis<R>(rClass);
Field rJoin = analysisR.findField(fieldName);
Field rInner = analysisR.findField(innerName);
Field rOuter = analysisR.findField(outerName);
if(fieldS == null || fieldT == null || rJoin == null) {
throw new IllegalArgumentException(fieldName);
}
LinkedList<R> joinedValues = new LinkedList<R>();
HashMap<Object,ArrayList<S>> hashS = new HashMap<Object,ArrayList<S>>();
for(int i = 0; i < outerFrame.size(); i++) {
S outerValue = outerFrame.object(i);
try {
Object key = fieldS.get(outerValue);
if(!hashS.containsKey(key)) {
hashS.put(key, new ArrayList<S>());
}
hashS.get(key).add(outerValue);
} catch (IllegalAccessException e) {
e.printStackTrace();
throw new IllegalArgumentException(String.format("Field %s " +
"was illegally accessed in Model %s", fieldName,
outerFrame.getModelClass().getSimpleName()));
}
}
for(int i = 0; i < size(); i++) {
T innerValue = object(i);
try {
Object key = fieldT.get(innerValue);
if(hashS.containsKey(key)) {
for(S outerValue : hashS.get(key)) {
R joined = (R)rClass.newInstance();
rJoin.set(joined, key);
if(rOuter != null) {
rOuter.set(joined, outerValue);
}
if(rInner != null) {
rInner.set(joined, innerValue);
}
joinedValues.add(joined);
}
}
} catch (IllegalAccessException e) {
e.printStackTrace();
throw new IllegalArgumentException(String.format("Field %s " +
"was illegally accessed in Model %s", fieldName,
getModelClass().getSimpleName()));
} catch (InstantiationException e) {
e.printStackTrace();
throw new IllegalArgumentException(String.format(
"Couldn't instantiate Model class %s",
rClass.getSimpleName()));
}
}
DataFrame<R> joinFrame = new DataFrame<R>(rClass, joinedValues);
return joinFrame;
}
public DataFrame<T> filter(Predicate<T> p) {
DataFrame<T> df = new DataFrame<T>(cls);
for(T val : objects) {
if(p.accepts(val)) {
df.addObject(val);
}
}
return df;
}
public void apply(Accumulator<T> acc) {
for(T val : objects) {
acc.accumulate(val);
}
}
/**
* Saves the data back out to a particular file -- this file then becomes the DataFrame's
* "default" file, and is the target when future calls are made to the save() method
* (no parameters).
*
* @param f
* @throws IOException
*/
public void save(File f) throws IOException {
writeTable(objects, fieldAnalysis.getFieldNames(), f);
file = f;
}
/**
* Saves the data back out to the "default" file (either the file from which the data
* was loaded, or the file that was given as a parameter to the last call to save(File).)
*
* @throws IOException
*/
public void save() throws IOException {
save(file);
}
/**
* Adds new rows to the DataFrame object.
*
* @param objs
*/
public void addObjects(Iterator<T> objs) {
while(objs.hasNext()) {
addObject(objs.next());
}
}
public File getFile() { return file; }
public Class<T> getModelClass() { return cls; }
public void addObjects(Collection<T> objs) { objects.addAll(objs); }
public void addObject(T obj) { objects.add(obj); }
public Vector<String> getFields() { return fieldAnalysis.getFieldNames(); }
public T object(int i) { return objects.get(i); }
public int size() { return objects.size(); }
/**
* For a particular named field in the model class, this returns a Set of all distinct
* values of that field among the rows of this DataFrame.
*
* @param <FT>
* @param fieldName
* @return
*/
public <FT> Set<FT> fieldValues(String fieldName) {
HashSet<FT> values = new HashSet<FT>();
try {
Field field = cls.getField(fieldName);
for(T obj : objects) {
FT objFieldValue = (FT)field.get(obj);
values.add(objFieldValue);
}
} catch (NoSuchFieldException e) {
e.printStackTrace();
throw new IllegalArgumentException(fieldName);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
return values;
}
public Double pearsonCorrelation(String xCol, String yCol) {
Double[] x = asVector(xCol);
Double[] y = asVector(yCol);
double xSum = 0.0, ySum = 0.0, xSqSum = 0.0, ySqSum = 0.0, xySum = 0.0;
int N = 0;
for(int i = 0; i < x.length; i++) {
if(x[i] != null && y[i] != null) {
N += 1;
xSum += x[i];
ySum += y[i];
xSqSum += (x[i] * x[i]);
ySqSum += (y[i] * y[i]);
xySum += (x[i] * y[i]);
}
}
/**
* (n \sum xy - \sum x \sum y)
* r_xy = -------------------------------------------------------------
* sqrt(n \sum x^2 - (\sum x)^2) * sqrt(n \sum y^2 - (\sum y)^2)
*/
if(N == 0) { throw new IllegalStateException("No values for correlation."); }
double n = (double)N;
double numer = n * xySum - (xSum * ySum);
double denom1 = Math.sqrt(n * xSqSum - (xSum*xSum));
double denom2 = Math.sqrt(n * ySqSum - (ySum * ySum));
double denom = denom1 * denom2;
return numer / denom;
}
public Double mean(String col) {
int count = 0;
Double sum = null;
Field field = fieldAnalysis.findField(col);
if(field != null && Model.isSubclass(field.getType(), Number.class)) {
for(T obj : objects) {
try {
Number num = (Number)field.get(obj);
if(num != null) {
count += 1;
sum = (sum == null ? num.doubleValue() : sum + num.doubleValue());
}
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
return count > 0 ? sum / (double)count : sum;
}
public Double variance(String col) {
Double mean = mean(col);
if(mean == null) { return null; }
return squaredError(col, new Function.Constant<T,Double>(mean));
}
public Double squaredError(String col, Function<T,Double> baseliner) {
int count = 0;
Double sum = null;
Field field = fieldAnalysis.findField(col);
if(field != null && Model.isSubclass(field.getType(), Number.class)) {
for(T obj : objects) {
try {
Number num = (Number)field.get(obj);
if(num != null) {
count += 1;
double val = num.doubleValue();
double baseline = baseliner.valueAt(obj);
double diff = val-baseline;
double diff2 = diff*diff;
sum = (sum == null ? diff2 : sum + diff2);
}
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
return count > 0 ? sum / (double)count : sum;
}
public Double[][] asMatrix(String... cols) {
return asMatrix(objects, cols);
}
public Double[] asVector(String fieldName) {
return asVector(objects, fieldName);
}
public Double[] asVector(ArrayList<T> rows, String fieldName) {
Double[] array = new Double[rows.size()];
for(int i = 0; i < rows.size(); i++) {
T row = rows.get(i);
try {
Field field = cls.getField(fieldName);
Class type = field.getType();
if(Model.isSubclass(type, Number.class)) {
Object value = field.get(row);
if(value == null){
array[i] = null;
} else {
array[i] = ((Number)value).doubleValue();
}
} else {
array[i] = null;
}
} catch (NoSuchFieldException e) {
e.printStackTrace();
array[i] = null;
} catch (IllegalAccessException e) {
e.printStackTrace();
array[i] = null;
}
}
return array;
}
public Double[][] asMatrix(ArrayList<T> rows, String... cols) {
Vector<String> colFields = new Vector<String>();
for(int i = 0; i < cols.length; i++) { colFields.add(cols[i]); }
return asMatrix(rows, colFields);
}
public Double[][] asMatrix(ArrayList<T> rows, Vector<String> colFields) {
Double[][] array = new Double[rows.size()][colFields.size()];
for(int i = 0; i < rows.size(); i++) {
T row = rows.get(i);
for(int j = 0; j < colFields.size(); j++) {
String fieldName = colFields.get(j);
try {
Field field = cls.getField(fieldName);
Class type = field.getType();
if(Model.isSubclass(type, Number.class)) {
Object value = field.get(row);
if(value == null){
array[i][j] = null;
} else {
array[i][j] = ((Number)value).doubleValue();
}
} else {
array[i][j] = null;
}
} catch (NoSuchFieldException e) {
e.printStackTrace();
array[i][j] = null;
} catch (IllegalAccessException e) {
e.printStackTrace();
array[i][j] = null;
}
}
}
return array;
}
/** Parsing Code **/
private void writeTable(Collection<T> lines, Vector<String> fields, File f) throws IOException {
PrintStream ps = new PrintStream(new FileOutputStream(f));
for(int i = 0; i < fields.size(); i++) {
if(i > 0) { ps.print("\t"); }
ps.print(fields.get(i));
}
ps.println();
for(T line : lines) {
writeLine(line, fields, ps);
}
ps.close();
}
private ArrayList<T> parse(File f, boolean header) throws IOException {
return parse(f, header, (String[])null);
}
private ArrayList<T> parse(File f, boolean header, String... fieldArray) throws IOException {
ArrayList<T> values = new ArrayList<T>();
BufferedReader br = new BufferedReader(new FileReader(f));
Vector<String> fields = new Vector<String>();
String line = null;
String[] array = null;
String sep = "\\s+";
if(header) {
line = br.readLine();
}
if(fieldArray != null && fieldArray.length > 0) {
for(int i = 0; i < fieldArray.length; i++) {
fields.add(fieldArray[i]);
}
} else if (header && line != null) {
array = line.split(sep);
for(int i = 0; i < array.length; i++) { fields.add(array[i]); }
}
Vector<Boolean> quoted = new Vector<Boolean>();
for(String fo : fields) {
boolean isquoted = fieldAnalysis.getStaticSwitch(String.format("quote_%s", fo), false);
quoted.add(!isquoted);
}
int ignored = 0;
while((line = br.readLine()) != null) {
line = line.trim();
if(line.length() > 0) {
array = line.split(sep);
T value = parseLine(array, fields, quoted);
if(value != null) {
values.add(value);
} else {
ignored += 1;
}
}
}
System.out.println(String.format("Parsed %d lines from %s", values.size(), f.getName()));
if(ignored > 0) {
System.err.println(String.format("Ignored %d lines from %s", ignored, f.getName()));
}
br.close();
return values;
}
private void writeLine(T modelObject, Vector<String> fieldOrder, PrintStream ps) {
Class cls = modelObject.getClass();
int fi = 0;
for(String fieldName : fieldOrder) {
try {
Field field = cls.getField(fieldName);
Object value = field.get(modelObject);
if(fi != 0) { ps.print("\t"); }
if(value != null) {
ps.print(value.toString());
} else {
ps.print("NA");
}
} catch (NoSuchFieldException e) {
e.printStackTrace();
ps.print("NA");
} catch (IllegalAccessException e) {
e.printStackTrace();
ps.print("NA");
}
fi += 1;
}
if(fi > 0) { ps.println(); }
}
private static Pattern quotePattern = Pattern.compile("^\\s*\"(.*)\"\\s*$");
private String extractQuoted(String quoted) {
Matcher m = quotePattern.matcher(quoted);
if(m.matches()) {
String value = m.group(1);
return value;
} else {
return quoted;
}
}
public T parseLine(String[] array, Vector<String> fieldOrder, Vector<Boolean> quoted) {
if(fieldOrder.size() > array.length) {
String arraystr = "";
for(int i = 0; i < array.length; i++) {
arraystr += array[i] + " ";
}
String msg = String.format("fieldOrder.size() == %d (%s) exceeded array.length == %d : %s",
fieldOrder.size(), fieldOrder.toString(), array.length, arraystr);
throw new IllegalArgumentException(msg);
}
T val = null;
try {
val = cls.newInstance();
for(int i = 0; i < fieldOrder.size(); i++) {
String fieldName = fieldOrder.get(fieldOrder.size()-1-i);
String valueString = array[array.length-1-i];
if(quoted.get(i)) {
valueString = extractQuoted(valueString);
}
boolean isNA = valueString.equals("NA");
try {
Field f = cls.getField(fieldName);
Class type = f.getType();
if(Model.isSubclass(type, Double.class)) {
try {
Double fieldValue = isNA ? null : Double.parseDouble(valueString);
f.set(val, fieldValue);
} catch(NumberFormatException nfe) {
f.set(val, null); // missing value.
}
} else if(Model.isSubclass(type, Boolean.class)) {
Boolean fieldValue = isNA ? null : Boolean.parseBoolean(valueString);
f.set(val, fieldValue);
} else if (Model.isSubclass(type, Integer.class)) {
try {
Integer fieldValue = isNA ? null : Integer.parseInt(valueString);
f.set(val, fieldValue);
} catch(NumberFormatException nfe) {
f.set(val, null); // missing value.
}
} else if (Model.isSubclass(type, String.class)) {
String fieldValue = valueString;
f.set(val, fieldValue);
} else {
System.err.println(String.format(
"Field %s has unsupported parsing type %s",
f.getName(), type.getName()));
}
} catch (NoSuchFieldException e) {
}
}
} catch (InstantiationException e) {
e.printStackTrace();
} catch (IllegalAccessException e) {
e.printStackTrace();
}
return val;
}
}
class Test extends Model {
public String Weekend, Decision, Weather, Parents, Money;
}