/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import java.io.Serializable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
/**
* {@code Partition} takes a {@code PCollection<T>} and a
* {@code PartitionFn}, uses the {@code PartitionFn} to split the
* elements of the input {@code PCollection} into {@code N} partitions, and
* returns a {@code PCollectionList<T>} that bundles {@code N}
* {@code PCollection<T>}s containing the split elements.
*
* <p>Example of use:
* <pre> {@code
* PCollection<Student> students = ...;
* // Split students up into 10 partitions, by percentile:
* PCollectionList<Student> studentsByPercentile =
* students.apply(Partition.of(10, new PartitionFn<Student>() {
* public int partitionFor(Student student, int numPartitions) {
* return student.getPercentile() // 0..99
* * numPartitions / 100;
* }}));
* for (int i = 0; i < 10; i++) {
* PCollection<Student> partition = studentsByPercentile.get(i);
* ...
* }
* } </pre>
*
* <p>By default, the {@code Coder} of each of the
* {@code PCollection}s in the output {@code PCollectionList} is the
* same as the {@code Coder} of the input {@code PCollection}.
*
* <p>Each output element has the same timestamp and is in the same windows
* as its corresponding input element, and each output {@code PCollection}
* has the same
* {@link org.apache.beam.sdk.transforms.windowing.WindowFn}
* associated with it as the input.
*
* @param <T> the type of the elements of the input and output
* {@code PCollection}s
*/
public class Partition<T> extends PTransform<PCollection<T>, PCollectionList<T>> {

  /**
   * A function object that chooses an output partition for an element.
   *
   * @param <T> the type of the elements being partitioned
   */
  public interface PartitionFn<T> extends Serializable {
    /**
     * Chooses the partition into which to put the given element.
     *
     * @param elem the element to be partitioned
     * @param numPartitions the total number of partitions ({@code >= 1})
     * @return index of the selected partition (in the range
     *   {@code [0..numPartitions-1]})
     */
    int partitionFor(T elem, int numPartitions);
  }

  /**
   * Returns a new {@code Partition} {@code PTransform} that divides
   * its input {@code PCollection} into the given number of partitions,
   * using the given partitioning function.
   *
   * @param numPartitions the number of partitions to divide the input
   *   {@code PCollection} into
   * @param partitionFn the function to invoke on each element to
   *   choose its output partition
   * @throws IllegalArgumentException if {@code numPartitions <= 0}
   * @throws NullPointerException if {@code partitionFn} is {@code null}
   */
  public static <T> Partition<T> of(
      int numPartitions, PartitionFn<? super T> partitionFn) {
    return new Partition<>(new PartitionDoFn<T>(numPartitions, partitionFn));
  }

  /////////////////////////////////////////////////////////////////////////////

  /**
   * Applies the partitioning {@code DoFn} to the input and gathers the
   * per-partition outputs into a {@code PCollectionList}.
   *
   * <p>The {@code ParDo} is given a fresh {@code TupleTag<Void>} as its first
   * tag and the per-partition tags as the remaining tags; the {@code DoFn}
   * only ever emits to the per-partition tags. The resulting collections are
   * appended in the order returned by {@code outputTags.getAll()}, and each
   * one has its coder set to the coder of the input.
   *
   * @param in the collection to partition
   * @return a list holding one {@code PCollection<T>} per partition tag
   */
  @Override
  public PCollectionList<T> expand(PCollection<T> in) {
    final TupleTagList outputTags = partitionDoFn.getOutputTags();

    PCollectionTuple outputs = in.apply(
        ParDo
            .of(partitionDoFn)
            .withOutputTags(new TupleTag<Void>(){}, outputTags));

    PCollectionList<T> pcs = PCollectionList.empty(in.getPipeline());
    Coder<T> coder = in.getCoder();

    for (TupleTag<?> outputTag : outputTags.getAll()) {
      // All the tuple tags are actually TupleTag<T>
      // And all the collections are actually PCollection<T>
      @SuppressWarnings("unchecked")
      TupleTag<T> typedOutputTag = (TupleTag<T>) outputTag;
      pcs = pcs.and(outputs.get(typedOutputTag).setCoder(coder));
    }
    return pcs;
  }

  /**
   * Adds this transform's display data, then includes the display data of the
   * underlying partitioning {@code DoFn} under the name {@code "partitionFn"}.
   *
   * @param builder the builder to populate
   */
  @Override
  public void populateDisplayData(DisplayData.Builder builder) {
    super.populateDisplayData(builder);
    builder.include("partitionFn", partitionDoFn);
  }

  // The DoFn that performs the actual routing of elements to partitions.
  // NOTE(review): this field is transient, so it will be null on an instance
  // restored through Java serialization; presumably the transform is only
  // used at pipeline-construction time -- confirm before relying on it
  // after deserialization.
  private final transient PartitionDoFn<T> partitionDoFn;

  /**
   * Creates a {@code Partition} backed by the given, already validated,
   * partitioning {@code DoFn}. Use {@link #of} to obtain instances.
   */
  private Partition(PartitionDoFn<T> partitionDoFn) {
    this.partitionDoFn = partitionDoFn;
  }

  /**
   * A {@code DoFn} that asks a {@link PartitionFn} for the partition index of
   * each input element and emits the element to the output tag stored at that
   * index.
   *
   * @param <X> the type of the elements being partitioned
   */
  private static class PartitionDoFn<X> extends DoFn<X, Void> {
    private final int numPartitions;
    private final PartitionFn<? super X> partitionFn;
    // One tag per partition; the tag at index i receives partition i.
    private final TupleTagList outputTags;

    /**
     * Constructs a PartitionDoFn, creating one new output tag per partition.
     *
     * @param numPartitions the number of partitions (and output tags)
     * @param partitionFn the function used to pick a partition per element
     * @throws IllegalArgumentException if {@code numPartitions <= 0}
     * @throws NullPointerException if {@code partitionFn} is {@code null}
     */
    public PartitionDoFn(int numPartitions, PartitionFn<? super X> partitionFn) {
      if (numPartitions <= 0) {
        throw new IllegalArgumentException("numPartitions must be > 0");
      }
      // Fail fast: a null function would otherwise only surface later as an
      // anonymous NullPointerException in populateDisplayData or
      // processElement, far away from the call that supplied it.
      if (partitionFn == null) {
        throw new NullPointerException("partitionFn must not be null");
      }

      this.numPartitions = numPartitions;
      this.partitionFn = partitionFn;

      TupleTagList buildOutputTags = TupleTagList.empty();
      for (int partition = 0; partition < numPartitions; partition++) {
        buildOutputTags = buildOutputTags.and(new TupleTag<X>());
      }
      outputTags = buildOutputTags;
    }

    /**
     * Returns the per-partition output tags, in partition-index order of
     * creation.
     */
    public TupleTagList getOutputTags() {
      return outputTags;
    }

    /**
     * Emits the current element, unchanged, to the output tag selected by the
     * partition function.
     *
     * @param c the context supplying the element and receiving the output
     * @throws IndexOutOfBoundsException if the partition function returns an
     *   index outside {@code [0..numPartitions)}
     */
    @ProcessElement
    public void processElement(ProcessContext c) {
      X input = c.element();
      int partition = partitionFn.partitionFor(input, numPartitions);
      if (0 <= partition && partition < numPartitions) {
        // Every tag in outputTags was created as a TupleTag<X> in the
        // constructor, so this cast is safe.
        @SuppressWarnings("unchecked")
        TupleTag<X> typedTag = (TupleTag<X>) outputTags.get(partition);
        c.output(typedTag, input);
      } else {
        throw new IndexOutOfBoundsException(
            "Partition function returned out of bounds index: "
                + partition + " not in [0.." + numPartitions + ")");
      }
    }

    /**
     * Adds the partition count and the class of the partition function as
     * display data items.
     *
     * @param builder the builder to populate
     */
    @Override
    public void populateDisplayData(DisplayData.Builder builder) {
      super.populateDisplayData(builder);
      builder
          .add(DisplayData.item("numPartitions", numPartitions)
              .withLabel("Partition Count"))
          .add(DisplayData.item("partitionFn", partitionFn.getClass())
              .withLabel("Partition Function"));
    }
  }
}