/*
 * Joinery -- Data frames for Java
 * Copyright (c) 2014, 2015 IBM Corp.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package joinery.impl;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import joinery.DataFrame;
import joinery.DataFrame.Aggregate;
import joinery.DataFrame.KeyFunction;
import joinery.impl.Aggregation.Unique;

/**
 * Static helpers that build pivot tables from data frames.
 *
 * <p>Each public entry point groups the input frame by the row keys,
 * regroups every resulting row group by the column keys, and then hands
 * the two-level grouping to a shared private routine which lays out the
 * result and applies one aggregate per requested value column.</p>
 */
public class Pivoting {

    /**
     * Pivots a data frame using column indices for the row keys, the
     * column keys and the value columns.
     *
     * <p>Every value column is aggregated with a fresh {@link Unique}
     * instance (NOTE(review): presumably this requires a single distinct
     * value per output cell -- confirm against {@code Aggregation.Unique}).</p>
     *
     * @param df the data frame to pivot
     * @param rows indices of the columns whose values identify output rows
     * @param cols indices of the columns whose values identify output columns
     * @param values indices of the columns that supply the cell values
     * @param <V> the element type of the data frame
     * @return the pivoted data frame
     */
    public static <V> DataFrame<V> pivot(
            final DataFrame<V> df, final Integer[] rows,
            final Integer[] cols, final Integer[] values) {
        final DataFrame<V> grouped = df.groupBy(rows);
        final Map<Object, DataFrame<V>> exploded = grouped.explode();
        final Map<Integer, Unique<V>> aggregates = new LinkedHashMap<>();

        // regroup each row group by the pivot columns; the entry is updated
        // through setValue because that is the only map update the
        // entry-set contract allows while an iteration is in progress
        for (final Map.Entry<Object, DataFrame<V>> entry : exploded.entrySet()) {
            entry.setValue(entry.getValue().groupBy(cols));
        }

        for (final int v : values) {
            aggregates.put(v, new Unique<V>());
        }

        return pivot(exploded, aggregates, grouped.groups().columns());
    }

    /**
     * Pivots a data frame using key functions for the row and column keys
     * and a caller-supplied aggregate for each value column.
     *
     * @param df the data frame to pivot
     * @param rows key function used to group the frame into output rows
     * @param cols key function used to group each row group into output columns
     * @param values map from value column index to the aggregate applied to
     *               that column's data
     * @param <I> the element type of the input data frame
     * @param <O> the element type of the pivoted data frame
     * @return the pivoted data frame
     */
    public static <I, O> DataFrame<O> pivot(
            final DataFrame<I> df, final KeyFunction<I> rows, final KeyFunction<I> cols,
            final Map<Integer, ? extends Aggregate<I,O>> values) {
        final DataFrame<I> grouped = df.groupBy(rows);
        final Map<Object, DataFrame<I>> exploded = grouped.explode();

        // see the index-based overload: setValue is the iteration-safe update
        for (final Map.Entry<Object, DataFrame<I>> entry : exploded.entrySet()) {
            entry.setValue(entry.getValue().groupBy(cols));
        }

        return pivot(exploded, values, grouped.groups().columns());
    }

    /**
     * Builds the pivoted frame from a two-level grouping.
     *
     * <p>The work is done in three passes: allocate an (initially empty)
     * data list for every row/column pair, fill those lists from the
     * grouped frames, and finally reduce each list to a single cell value.</p>
     *
     * <p>An empty {@code grouped} map no longer fails with a
     * {@code NoSuchElementException}; all passes are skipped and the
     * result is constructed from an empty index and an empty column set.</p>
     *
     * @param grouped map from row key to the frame of that row group,
     *                itself grouped by the pivot columns
     * @param values map from value column index to its aggregate
     * @param columns indices of the columns used as row keys; these are
     *                carried over into the result unaggregated
     * @return the pivoted data frame
     */
    @SuppressWarnings("unchecked")
    private static <I, O> DataFrame<O> pivot(
            final Map<Object, DataFrame<I>> grouped,
            final Map<Integer, ? extends Aggregate<I,O>> values,
            final Set<Integer> columns) {
        final Set<Object> pivotCols = new LinkedHashSet<>();
        final Map<Object, Map<Object, List<I>>> pivotData = new LinkedHashMap<>();
        final Map<Object, Aggregate<I, ?>> pivotFunctions = new LinkedHashMap<>();

        // column names are taken from the first row group; guard against an
        // empty grouping so that an empty input does not throw here
        final List<Object> colNames = new ArrayList<>();
        if (!grouped.isEmpty()) {
            colNames.addAll(grouped.values().iterator().next().columns());
        }

        // allocate row -> column -> data maps
        for (final Map.Entry<Object, DataFrame<I>> rowEntry : grouped.entrySet()) {
            final Map<Object, List<I>> rowData = new LinkedHashMap<>();
            for (final int c : columns) {
                final Object colName = colNames.get(c);
                rowData.put(colName, new ArrayList<I>());
                pivotCols.add(colName);
            }
            for (final Object colKey : rowEntry.getValue().groups().keys()) {
                for (final int c : values.keySet()) {
                    final Object colName = name(colKey, colNames.get(c), values);
                    rowData.put(colName, new ArrayList<I>());
                    pivotCols.add(colName);
                    pivotFunctions.put(colName, values.get(c));
                }
            }
            pivotData.put(rowEntry.getKey(), rowData);
        }

        // collect data for row and column groups
        for (final Map.Entry<Object, DataFrame<I>> rowEntry : grouped.entrySet()) {
            final Object rowName = rowEntry.getKey();
            final Map<Object, List<I>> rowData = pivotData.get(rowName);
            final Map<Object, DataFrame<I>> byCol = rowEntry.getValue().explode();
            for (final Map.Entry<Object, DataFrame<I>> colEntry : byCol.entrySet()) {
                // add columns used as pivot rows
                for (final int c : columns) {
                    final Object colName = colNames.get(c);
                    final List<I> colData = rowData.get(colName);
                    // optimization, only add first value
                    // since the values are all the same (due to grouping)
                    colData.add(colEntry.getValue().get(0, c));
                }

                // add values for aggregation
                for (final int c : values.keySet()) {
                    final Object colName = name(colEntry.getKey(), colNames.get(c), values);
                    final List<I> colData = rowData.get(colName);
                    colData.addAll(colEntry.getValue().col(c));
                }
            }
        }

        // iterate over row, column pairs and apply aggregate functions
        final DataFrame<O> pivot = new DataFrame<>(pivotData.keySet(), pivotCols);
        for (final Object col : pivot.columns()) {
            for (final Object row : pivot.index()) {
                // null means this row never produced the given pivot column;
                // the cell is left untouched in that case
                final List<I> data = pivotData.get(row).get(col);
                if (data != null) {
                    final Aggregate<I, ?> func = pivotFunctions.get(col);
                    if (func != null) {
                        pivot.set(row, col, (O)func.apply(data));
                    } else {
                        // no aggregate registered: this is a row-key column,
                        // so the first collected value is used as is
                        pivot.set(row, col, (O)data.get(0));
                    }
                }
            }
        }

        return pivot;
    }

    /**
     * Computes the name of a pivot output column.
     *
     * <p>With at most one value column the column group key is returned
     * unchanged. With several value columns the result is a list holding
     * the value column name followed by the key (a key that is itself a
     * list contributes its elements individually).</p>
     *
     * @param key the column group key
     * @param name the name of the value column
     * @param values the requested value columns; only its size is consulted
     * @return the pivot column name
     */
    private static Object name(final Object key, final Object name, final Map<?, ?> values) {
        Object colName = key;

        // if multiple value columns are requested the
        // value column name must be added to the pivot column name
        if (values.size() > 1) {
            final List<Object> tmp = new ArrayList<>();
            tmp.add(name);
            if (key instanceof List) {
                tmp.addAll((List<?>) key);
            } else {
                tmp.add(key);
            }
            colName = tmp;
        }

        return colName;
    }
}