/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms.join;
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.join.CoGbkResult.CoGbkResultCoder;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple.TaggedKeyedPCollection;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
/**
 * A {@link PTransform} that joins several keyed tables by performing a
 * co-group-by-key over a {@link KeyedPCollectionTuple}. Values from all input
 * tables that share a key are gathered into a single {@link CoGbkResult}; the
 * values originating from one particular table are then retrieved from that
 * result with the {@link org.apache.beam.sdk.values.TupleTag} the table was
 * registered under.
 *
 * <p>Internally every input element is wrapped in a {@link RawUnionValue}
 * carrying the index of its source table, the wrapped tables are flattened
 * into one collection, grouped with {@link GroupByKey}, and each group is
 * turned into a {@link CoGbkResult}.
 *
 * <p>Example of a {@link CoGroupByKey} followed by a {@link ParDo} that
 * consumes the results:
 * <pre>{@code
 * PCollection<KV<K, V1>> pt1 = ...;
 * PCollection<KV<K, V2>> pt2 = ...;
 *
 * final TupleTag<V1> t1 = new TupleTag<>();
 * final TupleTag<V2> t2 = new TupleTag<>();
 * PCollection<KV<K, CoGbkResult>> coGbkResultCollection =
 *     KeyedPCollectionTuple.of(t1, pt1)
 *                          .and(t2, pt2)
 *                          .apply(CoGroupByKey.<K>create());
 *
 * PCollection<T> finalResultCollection =
 *     coGbkResultCollection.apply(ParDo.of(
 *         new DoFn<KV<K, CoGbkResult>, T>() {
 *           {@literal @}ProcessElement
 *           public void processElement(ProcessContext c) {
 *             KV<K, CoGbkResult> e = c.element();
 *             Iterable<V1> pt1Vals = e.getValue().getAll(t1);
 *             V2 pt2Val = e.getValue().getOnly(t2);
 *             ... Do Something ....
 *             c.output(...some T...);
 *           }
 *         }));
 * }</pre>
 *
 * @param <K> the type of the keys in the input and output
 *     {@code PCollection}s
 */
public class CoGroupByKey<K> extends
    PTransform<KeyedPCollectionTuple<K>,
               PCollection<KV<K, CoGbkResult>>> {

  /**
   * Creates a new {@code CoGroupByKey<K>} {@code PTransform}.
   *
   * @param <K> the type of the keys in the input and output
   *     {@code PCollection}s
   */
  public static <K> CoGroupByKey<K> create() {
    return new CoGroupByKey<>();
  }

  /** Instances are obtained through {@link #create()}. */
  private CoGroupByKey() { }

  /**
   * Expands this transform into tag / flatten / group / construct-result steps.
   *
   * @param input the tuple of keyed tables to join
   * @return one {@code KV<K, CoGbkResult>} per group produced by the
   *     underlying {@link GroupByKey}
   * @throws IllegalArgumentException if {@code input} is empty, or if one of
   *     its tables does not use a {@link KvCoder}
   */
  @Override
  public PCollection<KV<K, CoGbkResult>> expand(
      KeyedPCollectionTuple<K> input) {
    if (input.isEmpty()) {
      throw new IllegalArgumentException(
          "must have at least one input to a KeyedPCollections");
    }

    // The union coder is needed before any table can be tagged, so every
    // input's value coder is collected (and validated) up front.
    // TODO: Look at better integration of union types with the
    // schema specified in the input.
    UnionCoder unionCoder = buildUnionCoder(input);
    Coder<K> keyCoder = input.getKeyCoder();
    KvCoder<K, RawUnionValue> unionTableCoder =
        KvCoder.of(keyCoder, unionCoder);

    // TODO: Use the schema to order the indices rather than depending
    // on the fact that the schema ordering is identical to the ordering from
    // input.getJoinCollections().
    PCollectionList<KV<K, RawUnionValue>> unionTables =
        PCollectionList.empty(input.getPipeline());
    int unionIndex = 0;
    for (TaggedKeyedPCollection<K, ?> tagged : input.getKeyedCollections()) {
      unionTables = unionTables.and(
          makeUnionTable(unionIndex, tagged.pCollection, unionTableCoder));
      unionIndex++;
    }

    // Merge all tagged tables into one collection and group it by key.
    PCollection<KV<K, Iterable<RawUnionValue>>> groupedTable =
        unionTables
            .apply(Flatten.<KV<K, RawUnionValue>>pCollections())
            .apply(GroupByKey.<K, RawUnionValue>create());

    // Turn each group of tagged values into a CoGbkResult and attach the
    // matching coder to the output collection.
    CoGbkResultSchema resultSchema = input.getCoGbkResultSchema();
    return groupedTable
        .apply("ConstructCoGbkResultFn",
            ParDo.of(new ConstructCoGbkResultFn<K>(resultSchema)))
        .setCoder(KvCoder.of(keyCoder,
            CoGbkResultCoder.of(resultSchema, unionCoder)));
  }

  //////////////////////////////////////////////////////////////////////////////

  /**
   * Builds the {@link UnionCoder} over the value coders of all tables in the
   * given input, in the iteration order of
   * {@code input.getKeyedCollections()}.
   *
   * @throws IllegalArgumentException if a table does not use a {@link KvCoder}
   */
  private UnionCoder buildUnionCoder(KeyedPCollectionTuple<K> input) {
    List<Coder<?>> valueCoders = new ArrayList<>();
    for (TaggedKeyedPCollection<K, ?> tagged : input.getKeyedCollections()) {
      valueCoders.add(getValueCoder(tagged.pCollection));
    }
    return UnionCoder.of(valueCoders);
  }

  /**
   * Extracts the value coder from the coder of the given PCollection, which
   * must be a {@code KvCoder<K, V>}.
   *
   * @throws IllegalArgumentException if the PCollection's coder is not a
   *     {@link KvCoder}
   */
  private <V> Coder<V> getValueCoder(PCollection<KV<K, V>> pCollection) {
    Coder<?> elementCoder = pCollection.getCoder();
    if (elementCoder instanceof KvCoder<?, ?>) {
      // The instanceof check only establishes the raw type; the type
      // arguments are taken from the PCollection's declared element type.
      @SuppressWarnings("unchecked")
      KvCoder<K, V> kvCoder = (KvCoder<K, V>) elementCoder;
      return kvCoder.getValueCoder();
    }
    throw new IllegalArgumentException("PCollection does not use a KvCoder");
  }

  /**
   * Converts the given input PCollection into a UnionTable whose values are
   * tagged with the given union index, and sets the given unionTableEncoder
   * as the coder of the resulting collection.
   */
  private <V> PCollection<KV<K, RawUnionValue>> makeUnionTable(
      final int index,
      PCollection<KV<K, V>> pCollection,
      KvCoder<K, RawUnionValue> unionTableEncoder) {
    PCollection<KV<K, RawUnionValue>> unionTable =
        pCollection.apply("MakeUnionTable" + index,
            ParDo.of(new ConstructUnionTableFn<K, V>(index)));
    return unionTable.setCoder(unionTableEncoder);
  }

  /**
   * A DoFn that builds a UnionTable (i.e., a
   * {@code PCollection<KV<K, RawUnionValue>>}) out of a
   * {@code PCollection<KV<K, V>>} by wrapping each value in a
   * {@link RawUnionValue} stamped with a fixed index.
   */
  private static class ConstructUnionTableFn<K, V> extends
      DoFn<KV<K, V>, KV<K, RawUnionValue>> {

    /** Union index written into every emitted {@link RawUnionValue}. */
    private final int index;

    public ConstructUnionTableFn(int index) {
      this.index = index;
    }

    @ProcessElement
    public void processElement(ProcessContext c) {
      KV<K, V> element = c.element();
      RawUnionValue tagged = new RawUnionValue(index, element.getValue());
      c.output(KV.of(element.getKey(), tagged));
    }
  }

  /**
   * A DoFn that builds a {@link CoGbkResult} for each key of a grouped union
   * table.
   */
  private static class ConstructCoGbkResultFn<K>
      extends DoFn<KV<K, Iterable<RawUnionValue>>,
                   KV<K, CoGbkResult>> {

    /** Schema handed to every {@link CoGbkResult} this DoFn creates. */
    private final CoGbkResultSchema schema;

    public ConstructCoGbkResultFn(CoGbkResultSchema schema) {
      this.schema = schema;
    }

    @ProcessElement
    public void processElement(ProcessContext c) {
      KV<K, Iterable<RawUnionValue>> group = c.element();
      K key = group.getKey();
      Iterable<RawUnionValue> taggedValues = group.getValue();
      c.output(KV.of(key, new CoGbkResult(schema, taggedValues)));
    }
  }
}