package nl.helixsoft.stats; import java.io.PrintStream; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; import javax.swing.table.TableModel; import com.google.common.collect.HashMultimap; import nl.helixsoft.recordstream.BiFunction; import nl.helixsoft.recordstream.Predicate; import nl.helixsoft.recordstream.Record; import nl.helixsoft.recordstream.RecordStream; import nl.helixsoft.recordstream.ReduceFunctions; import nl.helixsoft.util.ObjectUtils; /** * A collection of static functions to perform complex transformations on DataFrames. * <p> * Examples: join two dataframes (left, right, full join, inner join), * convert a dataframe from long format to wide format and back (like a pivot table in Excel), * and more. */ public abstract class DataFrameOperation { /** Different ways to handle non-existing keys */ public enum JoinType { LEFT, RIGHT, FULL, INNER }; public static DataFrame merge(DataFrame a, Map<?, ?> b, String onColumn, String valueColumnName) { DataFrame dfNew = DefaultDataFrame.createWithHeader(new String[] { onColumn, valueColumnName }); for (Map.Entry<?, ?> e : b.entrySet()) { dfNew.rbind(e.getKey(), e.getValue()); } return a.merge(dfNew, onColumn); } /** @deprecated use merge (a, b).onColumns(onThisColumn, onThatColumn).fullJoin().get() */ public static DataFrame merge (DataFrame a, DataFrame b, int onThisColumn, int onThatColumn) { return merge (a, b, onThisColumn, onThatColumn, JoinType.FULL); } private static int addLinear(Set<Integer> values, List<Integer> select, List<Integer> nullSelect) { for (int i : values) { select.add(i); nullSelect.add(null); } return values.size(); } private static int addCartesian(Set<Integer> leftValues, Set<Integer> rightValues, List<Integer> leftSelect, List<Integer> rightSelect) { for (int left : leftValues) { for (int right : rightValues) { leftSelect.add(left); rightSelect.add(right); } } return leftValues.size() * rightValues.size(); } private static int addCombination(Set<Integer> leftValues, Set<Integer> rightValues, JoinType joinType, List<Integer> leftSelect, List<Integer> rightSelect) { switch (joinType) { case FULL: if (leftValues.isEmpty() && rightValues.isEmpty()) return 0; else if (leftValues.isEmpty()) return addLinear (rightValues, rightSelect, leftSelect); else if (rightValues.isEmpty()) return addLinear (leftValues, leftSelect, rightSelect); else return addCartesian (leftValues, rightValues, leftSelect, rightSelect); case LEFT: if (leftValues.isEmpty()) return 0; else if (rightValues.isEmpty()) return addLinear (leftValues, leftSelect, rightSelect); else return addCartesian (leftValues, rightValues, leftSelect, rightSelect); case RIGHT: if (rightValues.isEmpty()) return 0; else if (leftValues.isEmpty()) return addLinear (rightValues, rightSelect, leftSelect); else return addCartesian (leftValues, rightValues, leftSelect, rightSelect); case INNER: if (leftValues.isEmpty() || rightValues .isEmpty()) return 0; return addCartesian (leftValues, rightValues, leftSelect, rightSelect); default: throw new IllegalArgumentException ("Invalid value for JoinType"); } } public static DataFrame merge (DataFrame left, DataFrame right, int onLeftColumn, int onRightColumn, JoinType joinType) { HashMultimap<Object, Integer> leftIndex = HashMultimap.create(); Set<Object> allKeys = new HashSet<Object>(); for (int i = 0; i < left.getRowCount(); ++i) { Object key = left.getValueAt(i, onLeftColumn); leftIndex.put (key, i); allKeys.add(key); } HashMultimap<Object, Integer> rightIndex = HashMultimap.create(); for (int i = 0; i < right.getRowCount(); ++i) { Object key = right.getValueAt(i, onRightColumn); rightIndex.put (key, i); allKeys.add(key); } List<Integer> leftSelect = new ArrayList<Integer>(); List<Integer> rightSelect = new ArrayList<Integer>(); List<Object> keyColumn = new ArrayList<Object>(); for (Object key : allKeys) { Set<Integer> leftValues = leftIndex.get(key); Set<Integer> rightValues = rightIndex.get(key); int numrows = addCombination (leftValues, rightValues, joinType, leftSelect, rightSelect); for (int i = 0; i < numrows; ++i) keyColumn.add(key); } int[] leftColumns = new int[left.getColumnCount()-1]; int pos = 0; for (int col = 0; col < left.getColumnCount(); ++col) if (col != onLeftColumn) leftColumns[pos++] = col; pos = 0; int[] rightColumns = new int[right.getColumnCount()-1]; for (int col = 0; col < right.getColumnCount(); ++col) if (col != onRightColumn) rightColumns[pos++] = col; return DataFrameOperation.cbind ( new ListColumn(keyColumn, left.getColumnHeader(onLeftColumn).toString()), left.cut(leftColumns).select(leftSelect), right.cut(rightColumns).select(rightSelect) ); } public static DataFrame cbind(Object... columnLike) { List<Column<?>> resultColumns = new ArrayList<Column<?>>(); for (Object o : columnLike) { if (o instanceof DataFrame) { DataFrame df = (DataFrame)o; for (int i = 0; i < df.getColumnCount(); ++i) { Column<?> col = new DefaultColumnView(df, i); resultColumns.add (col); } } else if (o instanceof Matrix) { //TODO } else if (o instanceof Object[]) { //TODO } else if (o instanceof Column) { resultColumns.add ((Column<?>)o); } else if (o instanceof List) { resultColumns.add (new ListColumn((List)o, null)); } else { throw new IllegalArgumentException("Not a suitable column-like class"); } } Integer uniformColumnLength = null; List<String> headers = new ArrayList<String>(); for (Column<?> col : resultColumns) { if (uniformColumnLength == null) uniformColumnLength = col.getSize(); else if (uniformColumnLength != col.getSize()) throw new IllegalArgumentException ("Not all input columns have equal size: " + uniformColumnLength + " versus " + col.getSize()); headers.add((String)col.getHeader()); } DataFrame result = DataFrameOperation.createWithHeader(headers.toArray(new String[headers.size()])); for (int row = 0; row < uniformColumnLength; ++row) { Object[] rowObj = new Object[resultColumns.size()]; for (int col = 0; col < resultColumns.size(); ++col) { rowObj[col] = resultColumns.get(col).get(row); } result.rbind(rowObj); } return result; } /** @deprecated use merge (a, b).onColumns(onThisColumn, onThatColumn).joinType().get() */ public static DataFrame merge (DataFrame a, DataFrame b, String onThisColumn, String onThatColumn, JoinType joinType) { return merge (a, b, a.getColumnIndex(onThisColumn), b.getColumnIndex(onThatColumn), joinType); } /** @deprecated use merge (a, b).onColumn(onColumn).fullJoin().get() */ public static DataFrame merge (DataFrame a, DataFrame b, String onColumn) { return merge (a, b, onColumn, onColumn, JoinType.FULL); } /** @deprecated use merge (a, b).onColumn(onColumn).joinType().get() */ public static DataFrame merge (DataFrame a, DataFrame b, String onColumn, JoinType joinType) { return merge(a, b, onColumn, onColumn, joinType); } /** Start a merge operation using a MergeBuilder. Usage example: * DataFrameOperation.merge (df1, df2).onColumn("ID").fullJoin().get(); */ public static MergeBuilder merge (DataFrame a, DataFrame b) { return new MergeBuilder(a, b); } /** Builder class to make it easy to set parameters for a merge (i.e. join) operation. */ public static class MergeBuilder { private final DataFrame a; private final DataFrame b; private JoinType joinType = JoinType.FULL; private int aCol = -1; private int bCol = -1; MergeBuilder (DataFrame a, DataFrame b) { this.a = a; this.b = b; } /** Set the column names that are used in the join */ public MergeBuilder onColumns(String aColName, String bColName) { aCol = a.getColumnIndex(aColName); bCol = b.getColumnIndex(bColName); return this; } /** Set the column indices that are used in the join */ public MergeBuilder onColumns(int aCol, int bCol) { this.aCol = aCol; this.bCol = bCol; return this; } /** set the join column name, in the case that it's the same column name for both data frames */ public MergeBuilder onColumn(String colName) { aCol = a.getColumnIndex(colName); bCol = b.getColumnIndex(colName); return this; } /** Set the join type to FULL */ public MergeBuilder fullJoin() { joinType = JoinType.FULL; return this; } /** Set the join type to INNER */ public MergeBuilder innerJoin() { joinType = JoinType.INNER; return this; } /** Set the join type to RIGHT */ public MergeBuilder rightJoin() { joinType = JoinType.RIGHT; return this; } /** Set the join type to LEFT */ public MergeBuilder leftJoin() { joinType = JoinType.LEFT; return this; } /** Finally perform the merge operation */ public DataFrame get() { assert (aCol >= 0) : "Join column not defined"; assert (bCol >= 0) : "Join column not defined"; return DataFrameOperation.merge (a, b, aCol, bCol, joinType); } } public static WideFormatBuilder wideFormat(DataFrame a) { return new WideFormatBuilder(a); } //TODO: groupBy and wideFormat are somewhat similar - both have grouping functions and grouping columns... public static GroupByBuilder groupBy (DataFrame a, String groupColumn) { return new GroupByBuilder (a, groupColumn); } public static class GroupByBuilder { private static class Aggregate { String col; BiFunction func; } private DataFrame parent; private String groupColumn; private List<Aggregate> aggs = new ArrayList<Aggregate>(); List<String> headers = new ArrayList<String>(); GroupByBuilder (DataFrame parent, String groupColumn) { this.parent = parent; this.groupColumn = groupColumn; headers.add(groupColumn); } public GroupByBuilder agg(String col, BiFunction<? extends Object, ? extends Object, ? extends Object> func) { Aggregate agg = new Aggregate(); agg.col = col; agg.func = func; aggs.add(agg); headers.add(col); return this; } public DataFrame get() { DataFrame sorted = parent.sort(groupColumn); int colNum = headers.size(); String[] header = headers.toArray(new String[colNum]); DataFrame result = DataFrameOperation.createWithHeader(header); Object prev = null; Object[] row = null; for (Record r : sorted.asRecordIterable()) { Object current = r.get(groupColumn); if (row == null) { row = new Object[colNum]; } else if (!ObjectUtils.safeEquals(current, prev)) { result.rbind (row); row = new Object[colNum]; } row[0] = current; for (int i = 0; i < aggs.size(); ++i) { Aggregate agg = aggs.get(i); Object more = r.get(agg.col); Object chain = row[i+1]; row[i+1] = agg.func.apply(chain, more); } prev = current; } if (row != null) result.rbind (row); return result; } } public static class WideFormatBuilder { private final DataFrame frame; private String[] columns; private String[] rows; private String rowNameField; private String colNameField; private String value; //TODO: option for aggregate functions... private BiFunction<Object, Object, Object> reduce; public WideFormatBuilder(DataFrame a) { frame = a; } public WideFormatBuilder withRowFactor(String... string) { rows = string; return this; } public WideFormatBuilder withColumnFactor(String... string) { columns = string; return this; } public WideFormatBuilder withRowNames(String string) { rowNameField = string; return this; } public WideFormatBuilder withColNames(String string) { colNameField = string; return this; } public WideFormatBuilder withValue(String string) { value = string; reduce = ReduceFunctions.FIRST; return this; } public WideFormatBuilder reduce(String string, BiFunction<Object, Object, Object> func) { value = string; reduce = func; return this; } public static class CompoundKey implements Comparable<CompoundKey> { final String[] data; public CompoundKey(int length) { data = new String[length]; } @Override public int compareTo(CompoundKey other) { for (int i = 0; i < data.length; ++i) { Comparable a = data[i]; Comparable b = other.data[i]; if (a == null && b == null) continue; if (a == null) return -1; if (b == null) return 1; int result = a.compareTo(b); if (result == 0) continue; return result; } return 0; } @Override public boolean equals(Object o) { if (o == null || o.getClass() != CompoundKey.class) return false; CompoundKey other = (CompoundKey)o; for (int i = 0; i < data.length; ++i) { String a = data[i]; String b = other.data[i]; if (a == null && b == null) continue; if (a == null || b == null) return false; if (!a.equals(b)) return false; } return true; } @Override public int hashCode() { int result = 0; for (int i = 0; i < data.length; ++i) { Object o = data[i]; if (o == null) { result = result * 5; } else { result = result ^ o.hashCode(); } } return result; } public void put(int ix, String o) { data [ix] = o; } @Override public String toString() { StringBuilder result = new StringBuilder(); String sep = "["; for (Object o : data) { result.append(sep); result.append("" + o); sep = ", "; } result.append ("]"); return result.toString(); // return data[0]; } public String get(int ix) { return data [ix]; } } public DataFrame get() { // first pass: build factors SortedMap<CompoundKey, Integer> rowFactors = new TreeMap<CompoundKey, Integer>(); SortedMap<CompoundKey, Integer> columnFactors = new TreeMap<CompoundKey, Integer>(); for (Record r : frame.asRecordIterable()) { CompoundKey rowKey = selectFields(r, rows); rowFactors.put(rowKey, 0); CompoundKey colKey = selectFields(r, columns); columnFactors.put(colKey, 0); } int i = 0; List<String> rowNames = new ArrayList<String>(); for (Map.Entry<CompoundKey, Integer> e : rowFactors.entrySet()) { e.setValue(i++); rowNames.add("" + e.getKey()); } i = 0; List<Object> colNames = new ArrayList<Object>(); for (Map.Entry<CompoundKey, Integer> e : columnFactors.entrySet()) { e.setValue(i++); colNames.add(e.getKey()); } Matrix<Double> m = new Matrix<Double>(columnFactors.size(), rowFactors.size()); for (Record r : frame.asRecordIterable()) { CompoundKey rowKey = selectFields(r, rows); CompoundKey colKey = selectFields(r, columns); Object v = r.get(value); int row = rowFactors.get(rowKey); int col = columnFactors.get(colKey); Object value = m.get(row, col); Object reduced = reduce.apply(value, v); m.set(row, col, reduced); if (rowNameField != null) rowNames.set(row, "" + r.get(rowNameField)); if (colNameField != null) colNames.set(col, "" + r.get(colNameField)); } return MatrixDataFrame.fromMatrix(m, new DefaultHeader(colNames, columns.length), rowNames); } private CompoundKey selectFields(Record r, String[] selectedFields) { CompoundKey rowKey = new CompoundKey(selectedFields.length); for (int i = 0; i < selectedFields.length; ++i) { rowKey.put(i, "" + r.get(selectedFields[i])); } return rowKey; } } public static DataFrame fromArray(String[] header, Object[][] objects) { DataFrame df = DataFrameOperation.createWithHeader(header); for (Object[] row : objects) { df.rbind (row); } return df; } public static DataFrame rbind(DataFrame df1, DataFrame df2) { // TODO Auto-generated method stub return null; } public static DataFrame columnSort(DataFrame in) { return columnSort(in, new Comparator<Column<?>>() { @Override public int compare(Column<?> o1, Column<?> o2) { Comparable<Comparable> a1 = (Comparable<Comparable>)o1.getHeader(); Comparable a2 = (Comparable)o2.getHeader(); if (a1 == null && a2 == null) return 0; if (a1 == null) return 1; if (a2 == null) return -1; return a1.compareTo(a2); } }); } public static DataFrame columnSort(DataFrame in, Comparator<Column<?>> comparator) { List<Column<?>> views = new ArrayList<Column<?>>(); for (int i = 0; i < in.getColumnCount(); ++i) { views.add (new DefaultColumnView(in, i)); } Collections.sort (views, comparator); return new ColumnBoundDataFrame (views, in); } public static void toTsv(PrintStream out, DataFrame df) { for (int col = 0; col < df.getColumnCount(); ++col) { out.print ("\t"); out.print (df.getColumnHeader(col).toString()); } out.println(); for (int row = 0; row < df.getRowCount(); ++row) { out.print (df.getRowName(row)); for (int col = 0; col < df.getColumnCount(); ++col) { out.print ("\t"); out.print (df.getValueAt(row, col)); } out.println(); } } /** * Render DataFrame as html. * Excludes the <table> tag. */ public static void toHtml(PrintStream out, DataFrame df) { out.println ("<thead>"); for (int h = 0; h < df.getColumnHeader().getSubHeaderCount(); ++h) { out.println ("<tr><th></th>"); for (int c = 0; c < df.getColumnCount(); ++c) { Object o = df.getColumnHeader().get(c); String hstr = o.toString(); out.print("<th>" + hstr + "</th>"); } out.println ("</tr>"); } out.println ("</thead>"); for (int r = 0; r < df.getRowCount(); ++r) { out.println ("<tr><th>" + df.getRowName(r) + "</th>"); for (int c = 0; c < df.getColumnCount(); ++c) { out.print("<td>" + df.getValueAt(r, c) + "</td>"); } out.println ("</tr>"); } } public static TableModel asTableModel(DataFrame df) { return new DataFrameTableModel (df, false); } public static DataFrame createWithHeader (String... header) { return DefaultDataFrame.createWithHeader(header); } public static DataFrame createFromRecordStream (RecordStream input) { return DefaultDataFrame.createFromRecordStream(input); } /** * Create a new dataframe by selecting those rows of the input dataframe that match the predicate function. */ public static DataFrame filter(DataFrame df, Predicate<Record> predicate) { List<Integer> rowSelection = new ArrayList<Integer>(); for (int row = 0; row < df.getRowCount(); ++row) { if (predicate.accept(df.getRow(row))) { rowSelection.add(row); } } return df.select(rowSelection); } }