/* * Joinery -- Data frames for Java * Copyright (c) 2014, 2015 IBM Corp. * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package joinery.impl; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; import joinery.DataFrame; import joinery.DataFrame.Aggregate; import joinery.DataFrame.Function; import joinery.DataFrame.KeyFunction; import joinery.impl.Transforms.CumulativeFunction; public class Grouping implements Iterable<Map.Entry<Object, SparseBitSet>> { private final Map<Object, SparseBitSet> groups = new LinkedHashMap<>(); private final Set<Integer> columns = new LinkedHashSet<>(); public Grouping() { } public <V> Grouping(final DataFrame<V> df, final KeyFunction<V> function, final Integer ... columns) { final Iterator<List<V>> iter = df.iterator(); for (int r = 0; iter.hasNext(); r++) { final List<V> row = iter.next(); final Object key = function.apply(row); SparseBitSet group = groups.get(key); if (group == null) { group = new SparseBitSet(); groups.put(key, group); } group.set(r); } for (final int column : columns) { this.columns.add(column); } } public <V> Grouping(final DataFrame<V> df, final Integer ... columns) { this( df, columns.length == 1 ? new KeyFunction<V>() { @Override public Object apply(final List<V> value) { return value.get(columns[0]); } } : new KeyFunction<V>() { @Override public Object apply(final List<V> value) { final List<Object> key = new ArrayList<>(columns.length); for (final int column : columns) { key.add(value.get(column)); } return Collections.unmodifiableList(key); } }, columns ); } @SuppressWarnings("unchecked") public <V> DataFrame<V> apply(final DataFrame<V> df, final Function<?, ?> function) { if (df.isEmpty()) { return df; } final List<List<V>> grouped = new ArrayList<>(); final List<Object> names = new ArrayList<>(df.columns()); final List<Object> newcols = new ArrayList<>(); final List<Object> index = new ArrayList<>(); // construct new row index if (function instanceof Aggregate && !groups.isEmpty()) { for (final Object key : groups.keySet()) { index.add(key); } } // add key columns for (final int c : columns) { if (function instanceof Aggregate && !groups.isEmpty()) { final List<V> column = new ArrayList<>(); for (final Map.Entry<Object, SparseBitSet> entry : groups.entrySet()) { final SparseBitSet rows = entry.getValue(); final int r = rows.nextSetBit(0); column.add(df.get(r, c)); } grouped.add(column); newcols.add(names.get(c)); } else { grouped.add(df.col(c)); newcols.add(names.get(c)); } } // add aggregated data columns for (int c = 0; c < df.size(); c++) { if (!columns.contains(c)) { final List<V> column = new ArrayList<>(); if (groups.isEmpty()) { try { if (function instanceof Aggregate) { column.add((V)Aggregate.class.cast(function).apply(df.col(c))); } else { for (int r = 0; r < df.length(); r++) { column.add((V)Function.class.cast(function).apply(df.get(r, c))); } } } catch (final ClassCastException ignored) { } if (function instanceof CumulativeFunction) { CumulativeFunction.class.cast(function).reset(); } } else { for (final Map.Entry<Object, SparseBitSet> entry : groups.entrySet()) { final SparseBitSet rows = entry.getValue(); try { if (function instanceof Aggregate) { final List<V> values = new ArrayList<>(rows.cardinality()); for (int r = rows.nextSetBit(0); r >= 0; r = rows.nextSetBit(r + 1)) { values.add(df.get(r, c)); } column.add((V)Aggregate.class.cast(function).apply(values)); } else { for (int r = rows.nextSetBit(0); r >= 0; r = rows.nextSetBit(r + 1)) { column.add((V)Function.class.cast(function).apply(df.get(r, c))); } } } catch (final ClassCastException ignored) { } if (function instanceof CumulativeFunction) { CumulativeFunction.class.cast(function).reset(); } } } if (!column.isEmpty()) { grouped.add(column); newcols.add(names.get(c)); } } } if (newcols.size() <= columns.size()) { throw new IllegalArgumentException( "no results for aggregate function " + function.getClass().getSimpleName() ); } return new DataFrame<>(index, newcols, grouped); } public Set<Object> keys() { return groups.keySet(); } public Set<Integer> columns() { return columns; } @Override public Iterator<Map.Entry<Object, SparseBitSet>> iterator() { return groups.entrySet().iterator(); } }