/*
 * Joinery -- Data frames for Java
 * Copyright (c) 2014, 2015 IBM Corp.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

package joinery.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.math3.stat.correlation.StorelessCovariance;
import org.apache.commons.math3.stat.descriptive.StatisticalSummary;
import org.apache.commons.math3.stat.descriptive.StorelessUnivariateStatistic;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.apache.commons.math3.stat.descriptive.UnivariateStatistic;

import joinery.DataFrame;
import joinery.DataFrame.Aggregate;

/**
 * Aggregate functions over lists of data frame values, plus helpers that
 * expand summary statistics ({@link #describe(DataFrame)}) and compute a
 * covariance matrix ({@link #cov(DataFrame)}).
 */
public class Aggregation {

    /** Row labels produced by {@link #describe(DataFrame)}, in output order. */
    private static final List<String> DESCRIBE_STATS =
            Arrays.asList("count", "mean", "std", "var", "max", "min");

    /**
     * Converts a non-null cell value to a {@code double}.
     *
     * <p>{@code Boolean} values map to {@code 1.0} (true) or {@code 0.0}
     * (false); every other value must be a {@link Number}.
     *
     * @param value the non-null value to convert
     * @return the numeric value
     * @throws ClassCastException if the value is neither a Boolean nor a Number
     */
    private static double toDouble(final Object value) {
        if (value instanceof Boolean) {
            return Boolean.class.cast(value) ? 1.0 : 0.0;
        }
        return Number.class.cast(value).doubleValue();
    }

    /**
     * Counts the values in the list; null entries are included in the count.
     */
    public static class Count<V>
    implements Aggregate<V, Number> {
        @Override
        public Number apply(final List<V> values) {
            return Integer.valueOf(values.size());
        }
    }

    /**
     * Returns the single distinct value of the list.
     */
    public static class Unique<V>
    implements Aggregate<V, V> {
        /**
         * @param values the values to inspect
         * @return the first element of the list
         * @throws IllegalArgumentException if the list holds more than one
         *         distinct value
         */
        @Override
        public V apply(final List<V> values) {
            final Set<V> unique = new HashSet<>(values);
            if (unique.size() > 1) {
                throw new IllegalArgumentException("values not unique: " + unique);
            }
            // An empty list is not handled specially here: the exception
            // raised by get(0) propagates to the caller.
            return values.get(0);
        }
    }

    /**
     * Joins the distinct values of the list, in order of first occurrence,
     * into a single delimited string.
     */
    public static class Collapse<V>
    implements Aggregate<V, String> {
        private final String delimiter;

        /** Creates a collapse aggregate using {@code ","} as the delimiter. */
        public Collapse() {
            this(",");
        }

        /**
         * @param delimiter the text placed between consecutive values
         */
        public Collapse(final String delimiter) {
            this.delimiter = delimiter;
        }

        @Override
        public String apply(final List<V> values) {
            final Set<V> seen = new HashSet<>();
            final StringBuilder sb = new StringBuilder();
            for (final V value : values) {
                // Set.add returns false for values already emitted
                if (seen.add(value)) {
                    if (sb.length() > 0) {
                        sb.append(delimiter);
                    }
                    sb.append(String.valueOf(value));
                }
            }
            return sb.toString();
        }
    }

    /**
     * Base class for aggregates backed by a
     * {@link StorelessUnivariateStatistic}.
     *
     * <p>The wrapped statistic is cleared and refilled on every call to
     * {@link #apply(List)}, so one instance carries mutable state between
     * the start and end of each call.
     */
    private static abstract class AbstractStorelessStatistic<V>
    implements Aggregate<V, Number> {
        protected final StorelessUnivariateStatistic stat;

        protected AbstractStorelessStatistic(final StorelessUnivariateStatistic stat) {
            this.stat = stat;
        }

        /**
         * Feeds every non-null value into the statistic and returns its result.
         *
         * @param values the values to aggregate; nulls are skipped and
         *        Booleans count as 1 or 0
         * @return the statistic's result
         * @throws ClassCastException if a value is neither Boolean nor Number
         */
        @Override
        public Number apply(final List<V> values) {
            stat.clear();
            for (final V value : values) {
                if (value != null) {
                    stat.increment(toDouble(value));
                }
            }
            return stat.getResult();
        }
    }

    /** Aggregate backed by commons-math {@code summary.Sum}. */
    public static class Sum<V>
    extends AbstractStorelessStatistic<V> {
        public Sum() {
            super(new org.apache.commons.math3.stat.descriptive.summary.Sum());
        }
    }

    /** Aggregate backed by commons-math {@code summary.Product}. */
    public static class Product<V>
    extends AbstractStorelessStatistic<V> {
        public Product() {
            super(new org.apache.commons.math3.stat.descriptive.summary.Product());
        }
    }

    /** Aggregate backed by commons-math {@code moment.Mean}. */
    public static class Mean<V>
    extends AbstractStorelessStatistic<V> {
        public Mean() {
            super(new org.apache.commons.math3.stat.descriptive.moment.Mean());
        }
    }

    /** Aggregate backed by commons-math {@code moment.StandardDeviation}. */
    public static class StdDev<V>
    extends AbstractStorelessStatistic<V> {
        public StdDev() {
            super(new org.apache.commons.math3.stat.descriptive.moment.StandardDeviation());
        }
    }

    /** Aggregate backed by commons-math {@code moment.Variance}. */
    public static class Variance<V>
    extends AbstractStorelessStatistic<V> {
        public Variance() {
            super(new org.apache.commons.math3.stat.descriptive.moment.Variance());
        }
    }

    /** Aggregate backed by commons-math {@code moment.Skewness}. */
    public static class Skew<V>
    extends AbstractStorelessStatistic<V> {
        public Skew() {
            super(new org.apache.commons.math3.stat.descriptive.moment.Skewness());
        }
    }

    /** Aggregate backed by commons-math {@code moment.Kurtosis}. */
    public static class Kurtosis<V>
    extends AbstractStorelessStatistic<V> {
        public Kurtosis() {
            super(new org.apache.commons.math3.stat.descriptive.moment.Kurtosis());
        }
    }

    /** Aggregate backed by commons-math {@code rank.Min}. */
    public static class Min<V>
    extends AbstractStorelessStatistic<V> {
        public Min() {
            super(new org.apache.commons.math3.stat.descriptive.rank.Min());
        }
    }

    /** Aggregate backed by commons-math {@code rank.Max}. */
    public static class Max<V>
    extends AbstractStorelessStatistic<V> {
        public Max() {
            super(new org.apache.commons.math3.stat.descriptive.rank.Max());
        }
    }

    /**
     * Base class for aggregates backed by a {@link UnivariateStatistic}
     * that is evaluated over the complete array of values at once.
     */
    private static abstract class AbstractStatistic<V>
    implements Aggregate<V, Number> {
        protected final UnivariateStatistic stat;

        protected AbstractStatistic(final UnivariateStatistic stat) {
            this.stat = stat;
        }

        /**
         * Packs the non-null values into an array and evaluates the statistic
         * over it.
         *
         * @param values the values to aggregate; nulls are skipped and
         *        Booleans count as 1 or 0, matching the storeless aggregates
         * @return the statistic evaluated over the non-null values
         * @throws ClassCastException if a value is neither Boolean nor Number
         */
        @Override
        public Number apply(final List<V> values) {
            int count = 0;
            final double[] vals = new double[values.size()];
            for (final V val : values) {
                if (val != null) {
                    vals[count++] = toDouble(val);
                }
            }
            // only the first `count` slots were filled
            return stat.evaluate(vals, 0, count);
        }
    }

    /** Aggregate backed by commons-math {@code rank.Median}. */
    public static class Median<V>
    extends AbstractStatistic<V> {
        public Median() {
            super(new org.apache.commons.math3.stat.descriptive.rank.Median());
        }
    }

    /** Aggregate backed by commons-math {@code rank.Percentile}. */
    public static class Percentile<V>
    extends AbstractStatistic<V> {
        /**
         * @param quantile the value passed to the commons-math
         *        {@code Percentile} constructor
         */
        public Percentile(final double quantile) {
            super(new org.apache.commons.math3.stat.descriptive.rank.Percentile(quantile));
        }
    }

    /**
     * Collects the non-null values into a {@link SummaryStatistics} and
     * returns its summary.
     *
     * <p>The underlying statistics object is cleared and refilled on every
     * call to {@link #apply(List)}.
     */
    public static class Describe<V>
    implements Aggregate<V, StatisticalSummary> {
        private final SummaryStatistics stat = new SummaryStatistics();

        /**
         * @param values the values to summarize; nulls are skipped and
         *        Booleans count as 1 or 0
         * @return the summary of the values added
         * @throws ClassCastException if a value is neither Boolean nor Number
         */
        @Override
        public StatisticalSummary apply(final List<V> values) {
            stat.clear();
            for (final V value : values) {
                if (value != null) {
                    stat.addValue(toDouble(value));
                }
            }
            return stat.getSummary();
        }
    }

    /**
     * Builds the row name used by {@link #describe(DataFrame)} for one
     * statistic of one source row.
     *
     * @param df the data frame being described
     * @param row the source row name
     * @param stat the statistic label
     * @return {@code [row, stat]} when the frame has more than one index
     *         entry, otherwise just {@code stat}
     */
    private static Object name(final DataFrame<?> df, final Object row, final Object stat) {
        // df index size > 1 only happens if the aggregate describes a grouped data frame
        return df.index().size() > 1 ? Arrays.asList(row, stat) : stat;
    }

    /**
     * Expands a data frame whose cells hold {@link StatisticalSummary}
     * objects into a data frame with one row per statistic (count, mean,
     * std, var, max, min) for each source row.
     *
     * <p>Columns whose cells are not {@code StatisticalSummary} instances
     * are left out of the result.
     *
     * @param df the data frame holding summary objects
     * @return a new data frame of the individual statistics as Doubles
     */
    @SuppressWarnings("unchecked")
    public static <V> DataFrame<V> describe(final DataFrame<V> df) {
        final DataFrame<V> desc = new DataFrame<>();
        for (final Object col : df.columns()) {
            for (final Object row : df.index()) {
                final V value = df.get(row, col);
                if (!(value instanceof StatisticalSummary)) {
                    continue;
                }

                if (!desc.columns().contains(col)) {
                    desc.add(col);
                    // the first column added triggers creation of all the
                    // statistic rows, for every source row
                    if (desc.isEmpty()) {
                        for (final Object r : df.index()) {
                            for (final Object stat : DESCRIBE_STATS) {
                                desc.append(name(df, r, stat), Collections.<V>emptyList());
                            }
                        }
                    }
                }

                final StatisticalSummary summary = StatisticalSummary.class.cast(value);
                desc.set(name(df, row, "count"), col, (V)Double.valueOf((double)summary.getN()));
                desc.set(name(df, row, "mean"),  col, (V)Double.valueOf(summary.getMean()));
                desc.set(name(df, row, "std"),   col, (V)Double.valueOf(summary.getStandardDeviation()));
                desc.set(name(df, row, "var"),   col, (V)Double.valueOf(summary.getVariance()));
                desc.set(name(df, row, "max"),   col, (V)Double.valueOf(summary.getMax()));
                desc.set(name(df, row, "min"),   col, (V)Double.valueOf(summary.getMin()));
            }
        }
        return desc;
    }

    /**
     * Computes the covariance matrix of the numeric columns of a data frame.
     *
     * @param df the source data frame
     * @return a data frame with the numeric column names as its columns
     *         and one appended row per row of the covariance matrix
     */
    public static <V> DataFrame<Number> cov(final DataFrame<V> df) {
        final DataFrame<Number> num = df.numeric();
        final int width = num.size();
        final StorelessCovariance cov = new StorelessCovariance(width);

        // row-wise copy to double array and increment
        final double[] data = new double[width];
        for (final List<Number> row : num) {
            for (int i = 0; i < row.size(); i++) {
                data[i] = row.get(i).doubleValue();
            }
            cov.increment(data);
        }

        // row-wise copy results into new data frame
        final double[][] result = cov.getData();
        final DataFrame<Number> r = new DataFrame<>(num.columns());
        for (final double[] values : result) {
            // a fresh list per row, so the result does not depend on
            // whether append copies the list it is given
            final List<Number> row = new ArrayList<>(values.length);
            for (final double value : values) {
                row.add(value);
            }
            r.append(row);
        }

        return r;
    }
}