/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.transforms; import com.google.common.base.MoreObjects; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Objects; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.BigEndianLongCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.DoubleCoder; import org.apache.beam.sdk.transforms.Combine.AccumulatingCombineFn.Accumulator; /** * {@code PTransform}s for computing the arithmetic mean * (a.k.a. average) of the elements in a {@code PCollection}, or the * mean of the values associated with each key in a * {@code PCollection} of {@code KV}s. * * <p>Example 1: get the mean of a {@code PCollection} of {@code Long}s. * <pre> {@code * PCollection<Long> input = ...; * PCollection<Double> mean = input.apply(Mean.<Long>globally()); * } </pre> * * <p>Example 2: calculate the mean of the {@code Integer}s * associated with each unique key (which is of type {@code String}). * <pre> {@code * PCollection<KV<String, Integer>> input = ...; * PCollection<KV<String, Double>> meanPerKey = * input.apply(Mean.<String, Integer>perKey()); * } </pre> */ public class Mean { private Mean() { } // Namespace only /** * Returns a {@code PTransform} that takes an input * {@code PCollection<NumT>} and returns a * {@code PCollection<Double>} whose contents is the mean of the * input {@code PCollection}'s elements, or * {@code 0} if there are no elements. * * @param <NumT> the type of the {@code Number}s being combined */ public static <NumT extends Number> Combine.Globally<NumT, Double> globally() { return Combine.<NumT, Double>globally(Mean.<NumT>of()); } /** * Returns a {@code PTransform} that takes an input * {@code PCollection<KV<K, N>>} and returns a * {@code PCollection<KV<K, Double>>} that contains an output * element mapping each distinct key in the input * {@code PCollection} to the mean of the values associated with * that key in the input {@code PCollection}. * * <p>See {@link Combine.PerKey} for how this affects timestamps and bucketing. * * @param <K> the type of the keys * @param <NumT> the type of the {@code Number}s being combined */ public static <K, NumT extends Number> Combine.PerKey<K, NumT, Double> perKey() { return Combine.<K, NumT, Double>perKey(Mean.<NumT>of()); } /** * A {@code Combine.CombineFn} that computes the arithmetic mean * (a.k.a. average) of an {@code Iterable} of numbers of type * {@code N}, useful as an argument to {@link Combine#globally} or * {@link Combine#perKey}. * * <p>Returns {@code Double.NaN} if combining zero elements. * * @param <NumT> the type of the {@code Number}s being combined */ public static <NumT extends Number> Combine.AccumulatingCombineFn<NumT, CountSum<NumT>, Double> of() { return new MeanFn<>(); } ///////////////////////////////////////////////////////////////////////////// private static class MeanFn<NumT extends Number> extends Combine.AccumulatingCombineFn<NumT, CountSum<NumT>, Double> { /** * Constructs a combining function that computes the mean over * a collection of values of type {@code N}. */ @Override public CountSum<NumT> createAccumulator() { return new CountSum<>(); } @Override public Coder<CountSum<NumT>> getAccumulatorCoder( CoderRegistry registry, Coder<NumT> inputCoder) { return new CountSumCoder<>(); } } /** * Accumulator class for {@link MeanFn}. */ static class CountSum<NumT extends Number> implements Accumulator<NumT, CountSum<NumT>, Double> { long count = 0; double sum = 0.0; public CountSum() { this(0, 0); } public CountSum(long count, double sum) { this.count = count; this.sum = sum; } @Override public void addInput(NumT element) { count++; sum += element.doubleValue(); } @Override public void mergeAccumulator(CountSum<NumT> accumulator) { count += accumulator.count; sum += accumulator.sum; } @Override public Double extractOutput() { return count == 0 ? Double.NaN : sum / count; } @Override public boolean equals(Object other) { if (!(other instanceof CountSum)) { return false; } @SuppressWarnings("unchecked") CountSum<?> otherCountSum = (CountSum<?>) other; return (count == otherCountSum.count) && (sum == otherCountSum.sum); } @Override public int hashCode() { return Objects.hash(count, sum); } @Override public String toString() { return MoreObjects.toStringHelper(this) .add("count", count) .add("sum", sum) .toString(); } } static class CountSumCoder<NumT extends Number> extends AtomicCoder<CountSum<NumT>> { private static final Coder<Long> LONG_CODER = BigEndianLongCoder.of(); private static final Coder<Double> DOUBLE_CODER = DoubleCoder.of(); @Override public void encode(CountSum<NumT> value, OutputStream outStream) throws CoderException, IOException { LONG_CODER.encode(value.count, outStream); DOUBLE_CODER.encode(value.sum, outStream); } @Override public CountSum<NumT> decode(InputStream inStream) throws CoderException, IOException { return new CountSum<>( LONG_CODER.decode(inStream), DOUBLE_CODER.decode(inStream)); } @Override public void verifyDeterministic() throws NonDeterministicException { LONG_CODER.verifyDeterministic(); DOUBLE_CODER.verifyDeterministic(); } } }