/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.summarize.aggregation;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.java.summarize.NumericColumnSummary;
import static org.apache.flink.api.java.summarize.aggregation.CompensatedSum.ZERO;
/**
* Generic aggregator for all numeric types creates a summary of a column of numbers.
*
* Uses the Kahan summation algorithm to avoid numeric instability when computing variance.
* The algorithm is described in: "Scalable and Numerically Stable Descriptive Statistics in SystemML",
* Tian et al, International Conference on Data Engineering 2012
*
* Implementation that couldn't be generic for all numbers was pushed to subclasses.
* For example, there isn't a generic way to calculate min, max, sum, isNan, isInfinite
* for all numeric types so subclasses must implement these.
*
* @param <T> numeric type to aggregrate and create a summary, e.g. Integer, DoubleValue
*/
@Internal
public abstract class NumericSummaryAggregator<T extends Number> implements Aggregator<T,NumericColumnSummary<T>> {
private static final long serialVersionUID = 1L;
private long nonMissingCount = 0L; // count of elements that are NOT null, NaN, or Infinite
private long nullCount = 0L;
private long nanCount = 0L;
private long infinityCount = 0L;
// these fields are initialized by type specific subclasses
private Aggregator<T,T> min = initMin();
private Aggregator<T,T> max = initMax();
private Aggregator<T,T> sum = initSum();
private CompensatedSum mean = ZERO;
/**
* Sum of squares of differences from the current mean (used to calculate variance).
*
* The algorithm is described in: "Scalable and Numerically Stable Descriptive Statistics in SystemML",
* Tian et al, International Conference on Data Engineering 2012
*/
private CompensatedSum m2 = ZERO;
/**
* Add a value to the current aggregation
*/
@Override
public void aggregate(T value) {
if (value == null) {
nullCount++;
}
else if (isNan(value)) {
nanCount++;
}
else if(isInfinite(value)) {
infinityCount++;
}
else {
nonMissingCount++;
min.aggregate(value);
max.aggregate(value);
sum.aggregate(value);
double doubleValue = value.doubleValue();
double delta = doubleValue - mean.value();
mean = mean.add(delta / nonMissingCount);
m2 = m2.add(delta * (doubleValue - mean.value()));
}
}
/**
* combine two aggregations
*/
@Override
public void combine(Aggregator<T, NumericColumnSummary<T>> otherSameType) {
NumericSummaryAggregator<T> other = (NumericSummaryAggregator<T>) otherSameType;
nullCount += other.nullCount;
nanCount += other.nanCount;
infinityCount += other.infinityCount;
if (nonMissingCount == 0) {
nonMissingCount = other.nonMissingCount;
min = other.min;
max = other.max;
sum = other.sum;
mean = other.mean;
m2 = other.m2;
}
else if (other.nonMissingCount != 0) {
long combinedCount = nonMissingCount + other.nonMissingCount;
min.combine(other.min);
max.combine(other.max);
sum.combine(other.sum);
double deltaMean = other.mean.value() - mean.value();
mean = mean.add(deltaMean * other.nonMissingCount / combinedCount);
m2 = m2.add(other.m2).add(deltaMean * deltaMean * nonMissingCount * other.nonMissingCount / combinedCount);
nonMissingCount = combinedCount;
}
}
@Override
public NumericColumnSummary<T> result() {
Double variance = null;
if(nonMissingCount > 1) {
variance = m2.value() / (nonMissingCount - 1);
}
return new NumericColumnSummary<T>(
nonMissingCount,
nullCount,
nanCount,
infinityCount,
// if nonMissingCount was zero some fields should be undefined
nonMissingCount == 0 ? null : min.result(),
nonMissingCount == 0 ? null : max.result(),
nonMissingCount == 0 ? null : sum.result(),
nonMissingCount == 0 ? null : mean.value(),
variance,
variance == null ? null : Math.sqrt(variance) // standard deviation
);
}
// there isn't a generic way to calculate min, max, sum, isNan, isInfinite for all numeric types
// so subclasses must implement these
protected abstract Aggregator<T,T> initMin();
protected abstract Aggregator<T,T> initMax();
protected abstract Aggregator<T,T> initSum();
protected abstract boolean isNan(T number);
protected abstract boolean isInfinite(T number);
}