package nl.helixsoft.recordstream;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Transforms a {@link RecordStream} into another RecordStream by applying aggregate functions to records that group together.
 * <p>
 * Usage examples:
 * <ul>
 * <li>Calculate the sum, maximum, or average for each group
 * <li>Flatten a gene annotation map (for example GO annotation). If you start with a gene + GO table, with each row containing only one GO identifier, and a single gene occurring in multiple rows, you can transform this
 * into a table where each gene is in only one row, and the GO annotations are concatenated with a delimiter.
 * </ul>
 * <p>
 * A Reducer has the following parameters:
 * <ul>
 * <li>A parent RecordStream, which will be the input before transformation
 * <li>A grouping column. Two consecutive records are grouped together if they have the same value in the grouping column
 * <li>An accumulator, which is a map of (column, {@link GroupFunc}) pairs. The GroupFunc is applied to each value of that column in the group, and accumulates the result. Grouping functions can do things like
 * calculate the min, max, sum or average; concatenate values with a separator; or just return the first item.
 * </ul>
 * <p>
 * I started writing this utility class to work around limitations of the SPARQL GROUP BY implementation of Virtuoso 6.
 * (AVG and SUM cast double to int, CONCAT just takes the first value or complains about maximum row length in temp table, etc...)
*/ public class Reducer extends AbstractRecordStream { private final RecordStream parent; private final Map<String, GroupFunc> accumulator; private Object prevValue = null; private Record row; private int idxGroupVar; private final RecordMetaData rmd; /** * * @param parent The recordstream that we want to reduce * @param groupVar * @param accumulator * @throws StreamException */ public Reducer (RecordStream parent, String groupVar, Map<String, GroupFunc> accumulator) throws StreamException { this.parent = parent; this.accumulator = accumulator; //TODO: should really make a defensive copy List<String> outHeaders = new ArrayList<String>(); outHeaders.add (groupVar); for (String header : accumulator.keySet()) { outHeaders.add(header); } idxGroupVar = parent.getMetaData().getColumnIndex(groupVar); row = parent.getNext(); prevValue = row.get(idxGroupVar); // Reset the accumulator at start resetAccumulator(); rmd = new DefaultRecordMetaData(outHeaders); } private Record writeAccumulator() { Object[] vals = new Object[rmd.getNumCols()]; vals[0] = prevValue; for (int i = 1; i < rmd.getNumCols(); ++i) { String colName = rmd.getColumnName(i); vals[i] = accumulator.get(colName).getResult(); } return new DefaultRecord(rmd, vals); } private void resetAccumulator() { for (String h : accumulator.keySet()) { accumulator.get(h).clear(); } } private void accumulate() { // accumulate a row for (String h : accumulator.keySet()) { accumulator.get(h).accumulate(row); } } @Override public Record getNext() throws StreamException { if (row == null) return null; // nothing more to return while (true) { accumulate(); row = parent.getNext(); // are we at the end? if (row == null) { return writeAccumulator(); } // has groupVar changed? 
if (!(row.get(idxGroupVar).toString().equals(prevValue.toString()))) { Record result = writeAccumulator(); // start with a fresh accumulator prevValue = row.get(idxGroupVar); resetAccumulator(); return result; } } } public static class GenericGroupFunc<T, U> extends AbstractGroupFunc { final BiFunction<T, U, T> function; GenericGroupFunc (String colName, T initial, BiFunction<T, U, T> function) { super(colName); this.function = function; this.initial = initial; } private T chain; private final T initial; public void accumulate (Record val) { int i = getIdx(val); U more = (U)val.get(i); // can't check at compile time but will throw ClassCastException at runtime if value is not the expected type. chain = function.apply (chain, more); } public T getResult() { return chain; } public void clear() { chain = initial; } } /** Count the items in the group */ public static class Count implements GroupFunc { private int count = 0; public void accumulate (Record val) { count++; } public Object getResult() { return count; } public void clear() { count = 0; } } /** Calculate the log(average) of floating point values. */ public static class LogAverageFloat extends AbstractGroupFunc { public LogAverageFloat(String col) { super(col); } private int count = 0; private float sum = 0; public void accumulate (Record val) { int i = getIdx(val); count++; sum += (Float)val.get(i); } public Object getResult() { double result = Math.log (sum / count); return result; } public void clear() { count = 0; sum = 0; } } //TODO: replace with "Reduce". // requires splitting Accumulator... /** Function to a apply to a group of values from a RecordStream. {@link #accumulate} is invoked with each element of the group in turn. The class should hold a running tally of the elements seen. After the last * element, {@link #getResult} is called followed by {@link #clear} to prepare for the next group. Implementations are used in {@link Reducer}, there are several canned implementations available. 
*/ public interface GroupFunc { /** Accumulate another value */ public void accumulate(Record val); /** Get the result so far */ public Object getResult(); /** Reset the grouping function to the starting state, to prepare it for the next group. */ public void clear(); } /** Calculate the average of a group of Float values */ public static class AverageFloat extends AbstractGroupFunc { public AverageFloat(String col) { super(col); } private int count = 0; private float sum = 0; public void accumulate (Record val) { int i = getIdx(val); count++; sum += (Float)val.get(i); } public Object getResult() { return sum / count; } public void clear() { count = 0; sum = 0; } } /** Calculate the sum of a group of Float values */ public static class SumFloat extends AbstractGroupFunc { public SumFloat(String col) { super(col); } private float sum = 0; public void accumulate (Record val) { int i = getIdx(val); sum += (Float)val.get(i); } public Object getResult() { return sum; } public void clear() { sum = 0; } } /** Concatenate a group of String values with a separator between them. 
*/ public static class Concatenate extends AbstractGroupFunc { private final String sep; public Concatenate(String colName, String sep) { super(colName); this.sep = sep; } private boolean first = true; private StringBuilder builder = new StringBuilder(); public void accumulate (Record val) { int i = getIdx(val); if (first) first = false; else builder.append(sep); builder.append(val.get(i).toString()); } public Object getResult() { return builder.toString(); } public void clear() { first = true; builder = new StringBuilder(); } } /** put the group of objects in a {@link List} */ public static class AsList extends AbstractGroupFunc { public AsList(String col) { super(col); } private List<Object> list; public void accumulate (Record val) { int i = getIdx(val); list.add (val.get(i)); } public Object getResult() { return list; } public void clear() { list = new ArrayList<Object>(); } } /** put the group of objects in a {@link Set} */ public static class AsSet extends AbstractGroupFunc { private Set<Object> set; public AsSet(String col) { super(col); } public void accumulate (Record val) { int i = getIdx(val); set.add (val.get(i)); } public Object getResult() { return set; } public void clear() { set = new HashSet<Object>(); } } public static abstract class AbstractGroupFunc implements GroupFunc { protected int idx = -1; private final String colName; public AbstractGroupFunc (String col) { this.colName = col; } protected int getIdx (Record val) { if (idx < 0) { // lazy initialization of idx RecordMetaData rs = val.getMetaData(); for (int col = 0; col < rs.getNumCols(); ++col) if (rs.getColumnName(col).equals(colName)) { idx = col; break; } } return idx; } } @Override public RecordMetaData getMetaData() { return rmd; } @Override public void close() { parent.close(); } }