/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.spark.translation;

import com.google.common.base.Optional;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

/**
 * A set of group/combine functions to apply to Spark {@link org.apache.spark.rdd.RDD}s.
 */
public class GroupCombineFunctions {

  /**
   * An implementation of
   * {@link org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly}
   * for the Spark runner.
   */
  public static <K, V> JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupByKeyOnly(
      JavaRDD<WindowedValue<KV<K, V>>> rdd,
      Coder<K> keyCoder,
      WindowedValueCoder<V> wvCoder) {
    // We use coders to convert objects in the PCollection to byte arrays, so they
    // can be transferred over the network for the shuffle.
    JavaPairRDD<ByteArray, byte[]> pairRDD =
        rdd
            .map(new ReifyTimestampsAndWindowsFunction<K, V>())
            .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction())
            .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction())
            .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder));

    // Use a HashPartitioner with the context's default parallelism.
    Partitioner partitioner = new HashPartitioner(rdd.rdd().sparkContext().defaultParallelism());

    // Using mapPartitions preserves the partitioner and avoids an unnecessary
    // shuffle downstream.
    return pairRDD
        .groupByKey(partitioner)
        .mapPartitionsToPair(
            TranslationUtils.pairFunctionToPairFlatMapFunction(
                CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)),
            true)
        .mapPartitions(
            TranslationUtils.<K, Iterable<WindowedValue<V>>>fromPairFlatMapFunction(), true)
        .mapPartitions(
            TranslationUtils.functionToFlatMapFunction(
                WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction()),
            true);
  }
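
  // A minimal usage sketch (not part of the translation path): how a caller might
  // assemble the windowed-value coder and invoke groupByKeyOnly(). The parameter
  // names here are hypothetical placeholders; in the runner the arguments come
  // from the translation context.
  private static <K, V> JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>
      exampleGroupByKeyOnly(
          JavaRDD<WindowedValue<KV<K, V>>> input,
          Coder<K> keyCoder,
          Coder<V> valueCoder,
          WindowingStrategy<?, ?> windowingStrategy) {
    // Build a coder for the windowed values, as the translator does elsewhere in
    // this file, so values can be serialized for the shuffle.
    WindowedValue.FullWindowedValueCoder<V> wvCoder =
        WindowedValue.FullWindowedValueCoder.of(
            valueCoder, windowingStrategy.getWindowFn().windowCoder());
    return groupByKeyOnly(input, keyCoder, wvCoder);
  }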

  /**
   * Apply a composite {@link org.apache.beam.sdk.transforms.Combine.Globally} transformation.
   */
  public static <InputT, AccumT> Optional<Iterable<WindowedValue<AccumT>>> combineGlobally(
      JavaRDD<WindowedValue<InputT>> rdd,
      final SparkGlobalCombineFn<InputT, AccumT, ?> sparkCombineFn,
      final Coder<InputT> iCoder,
      final Coder<AccumT> aCoder,
      final WindowingStrategy<?, ?> windowingStrategy) {
    // coders.
    final WindowedValue.FullWindowedValueCoder<InputT> wviCoder =
        WindowedValue.FullWindowedValueCoder.of(
            iCoder, windowingStrategy.getWindowFn().windowCoder());
    final WindowedValue.FullWindowedValueCoder<AccumT> wvaCoder =
        WindowedValue.FullWindowedValueCoder.of(
            aCoder, windowingStrategy.getWindowFn().windowCoder());
    final IterableCoder<WindowedValue<AccumT>> iterAccumCoder = IterableCoder.of(wvaCoder);

    // Use coders to convert objects in the PCollection to byte arrays, so they
    // can be transferred over the network for the shuffle.
    // For readability, we add a comment with the actual type next to each byte[].
    // To shorten line length, we abbreviate:
    //---- WV: WindowedValue
    //---- Itr: Iterable
    //---- A: AccumT
    //---- I: InputT
    JavaRDD<byte[]> inputRDDBytes = rdd.map(CoderHelpers.toByteFunction(wviCoder));

    if (inputRDDBytes.isEmpty()) {
      return Optional.absent();
    }

    /*Itr<WV<A>>*/ byte[] accumulatedBytes = inputRDDBytes.aggregate(
        CoderHelpers.toByteArray(sparkCombineFn.zeroValue(), iterAccumCoder),
        new Function2</*Itr<WV<A>>*/ byte[], /*WV<I>*/ byte[], /*Itr<WV<A>>*/ byte[]>() {
          @Override
          public /*Itr<WV<A>>*/ byte[] call(/*Itr<WV<A>>*/ byte[] ab, /*WV<I>*/ byte[] ib)
              throws Exception {
            Iterable<WindowedValue<AccumT>> a = CoderHelpers.fromByteArray(ab, iterAccumCoder);
            WindowedValue<InputT> i = CoderHelpers.fromByteArray(ib, wviCoder);
            return CoderHelpers.toByteArray(sparkCombineFn.seqOp(a, i), iterAccumCoder);
          }
        },
        new Function2</*Itr<WV<A>>*/ byte[], /*Itr<WV<A>>*/ byte[], /*Itr<WV<A>>*/ byte[]>() {
          @Override
          public /*Itr<WV<A>>*/ byte[] call(/*Itr<WV<A>>*/ byte[] a1b, /*Itr<WV<A>>*/ byte[] a2b)
              throws Exception {
            Iterable<WindowedValue<AccumT>> a1 = CoderHelpers.fromByteArray(a1b, iterAccumCoder);
            Iterable<WindowedValue<AccumT>> a2 = CoderHelpers.fromByteArray(a2b, iterAccumCoder);
            Iterable<WindowedValue<AccumT>> merged = sparkCombineFn.combOp(a1, a2);
            return CoderHelpers.toByteArray(merged, iterAccumCoder);
          }
        });
    return Optional.of(CoderHelpers.fromByteArray(accumulatedBytes, iterAccumCoder));
  }
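
  // A minimal usage sketch (not part of the translation path): combineGlobally()
  // returns Optional.absent() for an empty input, so callers must handle that
  // case explicitly. The parameter names below are hypothetical placeholders.
  private static <InputT, AccumT> Iterable<WindowedValue<AccumT>> exampleCombineGlobally(
      JavaRDD<WindowedValue<InputT>> input,
      SparkGlobalCombineFn<InputT, AccumT, ?> combineFn,
      Coder<InputT> iCoder,
      Coder<AccumT> aCoder,
      WindowingStrategy<?, ?> windowingStrategy) {
    Optional<Iterable<WindowedValue<AccumT>>> maybeAccumulated =
        combineGlobally(input, combineFn, iCoder, aCoder, windowingStrategy);
    // Fall back to the combine function's zero value when the input was empty.
    return maybeAccumulated.isPresent() ? maybeAccumulated.get() : combineFn.zeroValue();
  }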

  /**
   * Apply a composite {@link org.apache.beam.sdk.transforms.Combine.PerKey} transformation.
   * <p>
   * This aggregation will apply Beam's {@link org.apache.beam.sdk.transforms.Combine.CombineFn}
   * via Spark's {@link JavaPairRDD#combineByKey(Function, Function2, Function2)} aggregation.
   * </p>
   * For streaming, this will be called from within a serialized context
   * (a DStream's transform callback), so the passed arguments need to be Serializable.
   */
  public static <K, InputT, AccumT>
      JavaPairRDD<K, Iterable<WindowedValue<KV<K, AccumT>>>> combinePerKey(
          JavaRDD<WindowedValue<KV<K, InputT>>> rdd,
          final SparkKeyedCombineFn<K, InputT, AccumT, ?> sparkCombineFn,
          final Coder<K> keyCoder,
          final Coder<InputT> iCoder,
          final Coder<AccumT> aCoder,
          final WindowingStrategy<?, ?> windowingStrategy) {
    // coders.
    final WindowedValue.FullWindowedValueCoder<KV<K, InputT>> wkviCoder =
        WindowedValue.FullWindowedValueCoder.of(
            KvCoder.of(keyCoder, iCoder), windowingStrategy.getWindowFn().windowCoder());
    final WindowedValue.FullWindowedValueCoder<KV<K, AccumT>> wkvaCoder =
        WindowedValue.FullWindowedValueCoder.of(
            KvCoder.of(keyCoder, aCoder), windowingStrategy.getWindowFn().windowCoder());
    final IterableCoder<WindowedValue<KV<K, AccumT>>> iterAccumCoder = IterableCoder.of(wkvaCoder);

    // We need to duplicate K as both the key of the JavaPairRDD as well as inside the value,
    // since the functions passed to combineByKey don't receive the associated key of each
    // value, and we need to map back into methods in Combine.KeyedCombineFn, which each
    // require the key in addition to the InputTs and AccumTs being merged/accumulated.
    // Once Spark provides a way to include keys in the arguments of combine/merge functions,
    // we won't need to duplicate the keys anymore.
    // The key has to be windowed in order to group by window as well.
    JavaPairRDD<K, WindowedValue<KV<K, InputT>>> inRddDuplicatedKeyPair =
        rdd.mapToPair(TranslationUtils.<K, InputT>toPairByKeyInWindowedValue());

    // Use coders to convert objects in the PCollection to byte arrays, so they
    // can be transferred over the network for the shuffle.
    // For readability, we add a comment with the actual type next to each byte[].
    // To shorten line length, we abbreviate:
    //---- WV: WindowedValue
    //---- Itr: Iterable
    //---- A: AccumT
    //---- I: InputT
    JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes =
        inRddDuplicatedKeyPair.mapToPair(CoderHelpers.toByteFunction(keyCoder, wkviCoder));

    JavaPairRDD</*K*/ ByteArray, /*Itr<WV<KV<K, A>>>*/ byte[]> accumulatedBytes =
        inRddDuplicatedKeyPairBytes.combineByKey(
            new Function</*WV<KV<K, I>>*/ byte[], /*Itr<WV<KV<K, A>>>*/ byte[]>() {
              @Override
              public /*Itr<WV<KV<K, A>>>*/ byte[] call(/*WV<KV<K, I>>*/ byte[] input) {
                WindowedValue<KV<K, InputT>> wkvi =
                    CoderHelpers.fromByteArray(input, wkviCoder);
                return CoderHelpers.toByteArray(
                    sparkCombineFn.createCombiner(wkvi), iterAccumCoder);
              }
            },
            new Function2</*Itr<WV<KV<K, A>>>*/ byte[], /*WV<KV<K, I>>*/ byte[],
                /*Itr<WV<KV<K, A>>>*/ byte[]>() {
              @Override
              public /*Itr<WV<KV<K, A>>>*/ byte[] call(
                  /*Itr<WV<KV<K, A>>>*/ byte[] acc, /*WV<KV<K, I>>*/ byte[] input) {
                Iterable<WindowedValue<KV<K, AccumT>>> wkvas =
                    CoderHelpers.fromByteArray(acc, iterAccumCoder);
                WindowedValue<KV<K, InputT>> wkvi =
                    CoderHelpers.fromByteArray(input, wkviCoder);
                return CoderHelpers.toByteArray(
                    sparkCombineFn.mergeValue(wkvi, wkvas), iterAccumCoder);
              }
            },
            new Function2</*Itr<WV<KV<K, A>>>*/ byte[], /*Itr<WV<KV<K, A>>>*/ byte[],
                /*Itr<WV<KV<K, A>>>*/ byte[]>() {
              @Override
              public /*Itr<WV<KV<K, A>>>*/ byte[] call(
                  /*Itr<WV<KV<K, A>>>*/ byte[] acc1, /*Itr<WV<KV<K, A>>>*/ byte[] acc2) {
                Iterable<WindowedValue<KV<K, AccumT>>> wkvas1 =
                    CoderHelpers.fromByteArray(acc1, iterAccumCoder);
                Iterable<WindowedValue<KV<K, AccumT>>> wkvas2 =
                    CoderHelpers.fromByteArray(acc2, iterAccumCoder);
                return CoderHelpers.toByteArray(
                    sparkCombineFn.mergeCombiners(wkvas1, wkvas2), iterAccumCoder);
              }
            });
    return accumulatedBytes.mapToPair(CoderHelpers.fromByteFunction(keyCoder, iterAccumCoder));
  }
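
  // A minimal usage sketch (not part of the translation path): invoking
  // combinePerKey() with the key and value coders unpacked from the input's
  // KvCoder, as a translator would. The names below are hypothetical placeholders.
  private static <K, InputT, AccumT>
      JavaPairRDD<K, Iterable<WindowedValue<KV<K, AccumT>>>> exampleCombinePerKey(
          JavaRDD<WindowedValue<KV<K, InputT>>> input,
          SparkKeyedCombineFn<K, InputT, AccumT, ?> combineFn,
          KvCoder<K, InputT> inputCoder,
          Coder<AccumT> aCoder,
          WindowingStrategy<?, ?> windowingStrategy) {
    return combinePerKey(
        input,
        combineFn,
        inputCoder.getKeyCoder(),
        inputCoder.getValueCoder(),
        aCoder,
        windowingStrategy);
  }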

  /**
   * An implementation of {@link Reshuffle} for the Spark runner.
   */
  public static <K, V> JavaRDD<WindowedValue<KV<K, V>>> reshuffle(
      JavaRDD<WindowedValue<KV<K, V>>> rdd,
      Coder<K> keyCoder,
      WindowedValueCoder<V> wvCoder) {
    // Use coders to convert objects in the PCollection to byte arrays, so they
    // can be transferred over the network for the shuffle.
    return rdd
        .map(new ReifyTimestampsAndWindowsFunction<K, V>())
        .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction())
        .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction())
        .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder))
        .repartition(rdd.getNumPartitions())
        .mapToPair(CoderHelpers.fromByteFunction(keyCoder, wvCoder))
        .map(TranslationUtils.<K, WindowedValue<V>>fromPairFunction())
        .map(TranslationUtils.<K, V>toKVByWindowInValue());
  }
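
  // A minimal usage sketch (not part of the translation path): reshuffle()
  // redistributes elements across partitions without changing them. The names
  // below are hypothetical placeholders.
  private static <K, V> JavaRDD<WindowedValue<KV<K, V>>> exampleReshuffle(
      JavaRDD<WindowedValue<KV<K, V>>> input,
      Coder<K> keyCoder,
      Coder<V> valueCoder,
      WindowingStrategy<?, ?> windowingStrategy) {
    WindowedValue.FullWindowedValueCoder<V> wvCoder =
        WindowedValue.FullWindowedValueCoder.of(
            valueCoder, windowingStrategy.getWindowFn().windowCoder());
    // Round-trips elements through bytes and repartitions them; the data itself
    // is unchanged.
    return reshuffle(input, keyCoder, wvCoder);
  }
}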