/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.transforms; import static com.google.common.base.Preconditions.checkArgument; import java.util.ArrayList; import java.util.List; import java.util.Random; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; /** * {@code PTransform}s for taking samples of the elements in a * {@code PCollection}, or samples of the values associated with each * key in a {@code PCollection} of {@code KV}s. * * <p>{@link #combineFn} can also be used manually, in combination with state and with the * {@link Combine} transform. */ public class Sample { /** Returns a {@link CombineFn} that computes a fixed-sized sample of its inputs. */ public static <T> CombineFn<T, ?, Iterable<T>> combineFn(int sampleSize) { return new FixedSizedSampleFn<>(sampleSize); } /** * {@code Sample#any(long)} takes a {@code PCollection<T>} and a limit, and * produces a new {@code PCollection<T>} containing up to limit * elements of the input {@code PCollection}. * * <p>If limit is greater than or equal to the size of the input * {@code PCollection}, then all the input's elements will be selected. * * <p>All of the elements of the output {@code PCollection} should fit into * main memory of a single worker machine. This operation does not * run in parallel. * * <p>Example of use: * <pre> {@code * PCollection<String> input = ...; * PCollection<String> output = input.apply(Sample.<String>any(100)); * } </pre> * * @param <T> the type of the elements of the input and output * {@code PCollection}s * @param limit the number of elements to take from the input */ public static <T> PTransform<PCollection<T>, PCollection<T>> any(long limit) { return new Any<>(limit); } /** * Returns a {@code PTransform} that takes a {@code PCollection<T>}, selects {@code sampleSize} * elements, uniformly at random, and returns a {@code PCollection<Iterable<T>>} containing the * selected elements. If the input {@code PCollection} has fewer than {@code sampleSize} elements, * then the output {@code Iterable<T>} will be all the input's elements. * * <p>All of the elements of the output {@code PCollection} should fit into * main memory of a single worker machine. This operation does not * run in parallel. * * <p>Example of use: * * <pre>{@code * PCollection<String> pc = ...; * PCollection<Iterable<String>> sampleOfSize10 = * pc.apply(Sample.fixedSizeGlobally(10)); * } * </pre> * * @param sampleSize the number of elements to select; must be {@code >= 0} * @param <T> the type of the elements */ public static <T> PTransform<PCollection<T>, PCollection<Iterable<T>>> fixedSizeGlobally( int sampleSize) { return new FixedSizeGlobally<>(sampleSize); } /** * Returns a {@code PTransform} that takes an input {@code PCollection<KV<K, V>>} and returns a * {@code PCollection<KV<K, Iterable<V>>>} that contains an output element mapping each distinct * key in the input {@code PCollection} to a sample of {@code sampleSize} values associated with * that key in the input {@code PCollection}, taken uniformly at random. If a key in the input * {@code PCollection} has fewer than {@code sampleSize} values associated with it, then the * output {@code Iterable<V>} associated with that key will be all the values associated with that * key in the input {@code PCollection}. * * <p>Example of use: * * <pre>{@code * PCollection<KV<String, Integer>> pc = ...; * PCollection<KV<String, Iterable<Integer>>> sampleOfSize10PerKey = * pc.apply(Sample.<String, Integer>fixedSizePerKey()); * } * </pre> * * @param sampleSize the number of values to select for each distinct key; must be {@code >= 0} * @param <K> the type of the keys * @param <V> the type of the values */ public static <K, V> PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> fixedSizePerKey( int sampleSize) { return new FixedSizePerKey<>(sampleSize); } ///////////////////////////////////////////////////////////////////////////// /** Implementation of {@link #any(long)}. */ private static class Any<T> extends PTransform<PCollection<T>, PCollection<T>> { private final long limit; /** * Constructs a {@code SampleAny<T>} PTransform that, when applied, * produces a new PCollection containing up to {@code limit} * elements of its input {@code PCollection}. */ private Any(long limit) { checkArgument(limit >= 0, "Expected non-negative limit, received %s.", limit); this.limit = limit; } @Override public PCollection<T> expand(PCollection<T> in) { PCollectionView<Iterable<T>> iterableView = in.apply(View.<T>asIterable()); return in.getPipeline() .apply(Create.of((Void) null).withCoder(VoidCoder.of())) .apply(ParDo.of(new SampleAnyDoFn<>(limit, iterableView)).withSideInputs(iterableView)) .setCoder(in.getCoder()); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); builder.add(DisplayData.item("sampleSize", limit) .withLabel("Sample Size")); } } /** Implementation of {@link #fixedSizeGlobally(int)}. */ private static class FixedSizeGlobally<T> extends PTransform<PCollection<T>, PCollection<Iterable<T>>> { private final int sampleSize; private FixedSizeGlobally(int sampleSize) { this.sampleSize = sampleSize; } @Override public PCollection<Iterable<T>> expand(PCollection<T> input) { return input.apply(Combine.globally(new FixedSizedSampleFn<T>(sampleSize))); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); builder.add(DisplayData.item("sampleSize", sampleSize) .withLabel("Sample Size")); } } /** Implementation of {@link #fixedSizeGlobally(int)}. */ private static class FixedSizePerKey<K, V> extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> { private final int sampleSize; private FixedSizePerKey(int sampleSize) { this.sampleSize = sampleSize; } @Override public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) { return input.apply(Combine.<K, V, Iterable<V>>perKey(new FixedSizedSampleFn<V>(sampleSize))); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); builder.add(DisplayData.item("sampleSize", sampleSize) .withLabel("Sample Size")); } } /** * A {@link DoFn} that returns up to limit elements from the side input PCollection. */ private static class SampleAnyDoFn<T> extends DoFn<Void, T> { long limit; final PCollectionView<Iterable<T>> iterableView; public SampleAnyDoFn(long limit, PCollectionView<Iterable<T>> iterableView) { this.limit = limit; this.iterableView = iterableView; } @ProcessElement public void processElement(ProcessContext c) { for (T i : c.sideInput(iterableView)) { if (limit-- <= 0) { break; } c.output(i); } } } /** * {@code CombineFn} that computes a fixed-size sample of a * collection of values. * * @param <T> the type of the elements */ public static class FixedSizedSampleFn<T> extends CombineFn<T, Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>, Iterable<T>> { private final int sampleSize; private final Top.TopCombineFn<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> topCombineFn; private final Random rand = new Random(); private FixedSizedSampleFn(int sampleSize) { if (sampleSize < 0) { throw new IllegalArgumentException("sample size must be >= 0"); } this.sampleSize = sampleSize; topCombineFn = new Top.TopCombineFn<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>( sampleSize, new KV.OrderByKey<Integer, T>()); } @Override public Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> createAccumulator() { return topCombineFn.createAccumulator(); } @Override public Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> addInput( Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> accumulator, T input) { accumulator.addInput(KV.of(rand.nextInt(), input)); return accumulator; } @Override public Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> mergeAccumulators( Iterable<Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>> accumulators) { return topCombineFn.mergeAccumulators(accumulators); } @Override public Iterable<T> extractOutput( Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>> accumulator) { List<T> out = new ArrayList<>(); for (KV<Integer, T> element : accumulator.extractOutput()) { out.add(element.getValue()); } return out; } @Override public Coder<Top.BoundedHeap<KV<Integer, T>, SerializableComparator<KV<Integer, T>>>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) { return topCombineFn.getAccumulatorCoder( registry, KvCoder.of(BigEndianIntegerCoder.of(), inputCoder)); } @Override public Coder<Iterable<T>> getDefaultOutputCoder( CoderRegistry registry, Coder<T> inputCoder) { return IterableCoder.of(inputCoder); } @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); builder.add(DisplayData.item("sampleSize", sampleSize) .withLabel("Sample Size")); } } }