/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.spark.translation.streaming;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable;
import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import javax.annotation.Nonnull;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
import org.apache.beam.runners.spark.aggregators.NamedAggregators;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.io.ConsoleIO;
import org.apache.beam.runners.spark.io.CreateStream;
import org.apache.beam.runners.spark.io.SparkUnboundedSource;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet;
import org.apache.beam.runners.spark.translation.BoundedDataset;
import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
import org.apache.beam.runners.spark.translation.SparkAssignWindowFn;
import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn;
import org.apache.beam.runners.spark.translation.SparkPCollectionView;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.runners.spark.translation.TransformEvaluator;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.runners.spark.translation.WindowingHelpers;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaInputDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

/**
 * Supports translation between a Beam transform and Spark's operations on DStreams.
 */
public final class StreamingTransformTranslator {

  private StreamingTransformTranslator() {
  }

  private static <T> TransformEvaluator<ConsoleIO.Write.Unbound<T>> print() {
    return new TransformEvaluator<ConsoleIO.Write.Unbound<T>>() {
      @Override
      public void evaluate(ConsoleIO.Write.Unbound<T> transform, EvaluationContext context) {
        @SuppressWarnings("unchecked")
        JavaDStream<WindowedValue<T>> dstream =
            ((UnboundedDataset<T>) context.borrowDataset(transform)).getDStream();
        dstream.map(WindowingHelpers.<T>unwindowFunction()).print(transform.getNum());
      }

      @Override
      public String toNativeString() {
        return ".print(...)";
      }
    };
  }

  private static <T> TransformEvaluator<Read.Unbounded<T>> readUnbounded() {
    return new TransformEvaluator<Read.Unbounded<T>>() {
      @Override
      public void evaluate(Read.Unbounded<T> transform, EvaluationContext context) {
        final String stepName = context.getCurrentTransform().getFullName();
        context.putDataset(
            transform,
            SparkUnboundedSource.read(
                context.getStreamingContext(),
                context.getRuntimeContext(),
                transform.getSource(),
                stepName));
      }

      @Override
      public String toNativeString() {
        return "streamingContext.<readFrom(<source>)>()";
      }
    };
  }
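
  /**
   * Translates {@link CreateStream} into a Spark queue stream: each pre-baked batch of
   * timestamped values is encoded with the windowed-value coder into an RDD, and the RDDs are
   * queued so Spark consumes one per batch interval ({@code queueStream(..., true)}).
   */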
  private static <T> TransformEvaluator<CreateStream<T>> createFromQueue() {
    return new TransformEvaluator<CreateStream<T>>() {
      @Override
      public void evaluate(CreateStream<T> transform, EvaluationContext context) {
        Coder<T> coder = context.getOutput(transform).getCoder();
        JavaStreamingContext jssc = context.getStreamingContext();
        Queue<Iterable<TimestampedValue<T>>> values = transform.getBatches();
        WindowedValue.FullWindowedValueCoder<T> windowCoder =
            WindowedValue.FullWindowedValueCoder.of(coder, GlobalWindow.Coder.INSTANCE);
        // create the DStream from queue.
        Queue<JavaRDD<WindowedValue<T>>> rddQueue = new LinkedBlockingQueue<>();
        for (Iterable<TimestampedValue<T>> tv : values) {
          Iterable<WindowedValue<T>> windowedValues =
              Iterables.transform(
                  tv,
                  new com.google.common.base.Function<TimestampedValue<T>, WindowedValue<T>>() {
                    @Override
                    public WindowedValue<T> apply(@Nonnull TimestampedValue<T> timestampedValue) {
                      return WindowedValue.of(
                          timestampedValue.getValue(),
                          timestampedValue.getTimestamp(),
                          GlobalWindow.INSTANCE,
                          PaneInfo.NO_FIRING);
                    }
                  });
          JavaRDD<WindowedValue<T>> rdd =
              jssc.sparkContext()
                  .parallelize(CoderHelpers.toByteArrays(windowedValues, windowCoder))
                  .map(CoderHelpers.fromByteFunction(windowCoder));
          rddQueue.offer(rdd);
        }
        JavaInputDStream<WindowedValue<T>> inputDStream = jssc.queueStream(rddQueue, true);
        UnboundedDataset<T> unboundedDataset =
            new UnboundedDataset<T>(
                inputDStream, Collections.singletonList(inputDStream.inputDStream().id()));
        // add pre-baked Watermarks for the pre-baked batches.
        Queue<GlobalWatermarkHolder.SparkWatermarks> times = transform.getTimes();
        GlobalWatermarkHolder.addAll(
            ImmutableMap.of(unboundedDataset.getStreamSources().get(0), times));
        context.putDataset(transform, unboundedDataset);
      }

      @Override
      public String toNativeString() {
        return "streamingContext.queueStream(...)";
      }
    };
  }
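
  /**
   * Translates {@link Flatten.PCollections} by unifying the underlying DStreams into one;
   * bounded inputs are first wrapped as single-RDD queue streams (which, as the TODO below
   * notes, is not recoverable from a checkpoint).
   */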
  private static <T> TransformEvaluator<Flatten.PCollections<T>> flattenPColl() {
    return new TransformEvaluator<Flatten.PCollections<T>>() {
      @SuppressWarnings("unchecked")
      @Override
      public void evaluate(Flatten.PCollections<T> transform, EvaluationContext context) {
        Map<TupleTag<?>, PValue> pcs = context.getInputs(transform);
        // since this is a streaming pipeline, at least one of the PCollections to "flatten" is
        // unbounded, meaning it represents a DStream.
        // So we could end up with an unbounded unified DStream.
        final List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>();
        final List<Integer> streamingSources = new ArrayList<>();
        for (PValue pv : pcs.values()) {
          checkArgument(
              pv instanceof PCollection,
              "Flatten had non-PCollection value in input: %s of type %s",
              pv,
              pv.getClass().getSimpleName());
          PCollection<T> pcol = (PCollection<T>) pv;
          Dataset dataset = context.borrowDataset(pcol);
          if (dataset instanceof UnboundedDataset) {
            UnboundedDataset<T> unboundedDataset = (UnboundedDataset<T>) dataset;
            streamingSources.addAll(unboundedDataset.getStreamSources());
            dStreams.add(unboundedDataset.getDStream());
          } else {
            // create a single RDD stream.
            Queue<JavaRDD<WindowedValue<T>>> q = new LinkedBlockingQueue<>();
            q.offer(((BoundedDataset) dataset).getRDD());
            // TODO: this is not recoverable from checkpoint!
            JavaDStream<WindowedValue<T>> dStream = context.getStreamingContext().queueStream(q);
            dStreams.add(dStream);
          }
        }
        // start by unifying streams into a single stream.
        JavaDStream<WindowedValue<T>> unifiedStreams =
            context.getStreamingContext().union(dStreams.remove(0), dStreams);
        context.putDataset(transform, new UnboundedDataset<>(unifiedStreams, streamingSources));
      }

      @Override
      public String toNativeString() {
        return "streamingContext.union(...)";
      }
    };
  }

  private static <T, W extends BoundedWindow> TransformEvaluator<Window.Assign<T>> window() {
    return new TransformEvaluator<Window.Assign<T>>() {
      @Override
      public void evaluate(final Window.Assign<T> transform, EvaluationContext context) {
        @SuppressWarnings("unchecked")
        UnboundedDataset<T> unboundedDataset =
            (UnboundedDataset<T>) context.borrowDataset(transform);
        JavaDStream<WindowedValue<T>> dStream = unboundedDataset.getDStream();
        JavaDStream<WindowedValue<T>> outputStream;
        if (TranslationUtils.skipAssignWindows(transform, context)) {
          // do nothing.
          outputStream = dStream;
        } else {
          outputStream =
              dStream.transform(
                  new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
                    @Override
                    public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> rdd)
                        throws Exception {
                      return rdd.map(new SparkAssignWindowFn<>(transform.getWindowFn()));
                    }
                  });
        }
        context.putDataset(
            transform,
            new UnboundedDataset<>(outputStream, unboundedDataset.getStreamSources()));
      }

      @Override
      public String toNativeString() {
        return "map(new <windowFn>())";
      }
    };
  }
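
  /**
   * Translates {@link GroupByKey} in two steps: a key-only grouping over the coder-serialized
   * bytes, followed by a group-also-by-window over the resulting window set.
   */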
  private static <K, V, W extends BoundedWindow> TransformEvaluator<GroupByKey<K, V>>
      groupByKey() {
    return new TransformEvaluator<GroupByKey<K, V>>() {
      @Override
      public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {
        @SuppressWarnings("unchecked")
        UnboundedDataset<KV<K, V>> inputDataset =
            (UnboundedDataset<KV<K, V>>) context.borrowDataset(transform);
        List<Integer> streamSources = inputDataset.getStreamSources();
        JavaDStream<WindowedValue<KV<K, V>>> dStream = inputDataset.getDStream();
        @SuppressWarnings("unchecked")
        final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder();
        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
        @SuppressWarnings("unchecked")
        final WindowingStrategy<?, W> windowingStrategy =
            (WindowingStrategy<?, W>) context.getInput(transform).getWindowingStrategy();
        @SuppressWarnings("unchecked")
        final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();

        //--- coders.
        final WindowedValue.WindowedValueCoder<V> wvCoder =
            WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());

        //--- group by key only.
        JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKeyStream =
            dStream.transform(
                new Function<
                    JavaRDD<WindowedValue<KV<K, V>>>,
                    JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>>() {
                  @Override
                  public JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> call(
                      JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception {
                    return GroupCombineFunctions.groupByKeyOnly(rdd, coder.getKeyCoder(), wvCoder);
                  }
                });

        //--- now group also by window.
        JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream =
            SparkGroupAlsoByWindowViaWindowSet.groupAlsoByWindow(
                groupedByKeyStream,
                coder.getKeyCoder(),
                wvCoder,
                windowingStrategy,
                runtimeContext,
                streamSources);

        context.putDataset(transform, new UnboundedDataset<>(outStream, streamSources));
      }

      @Override
      public String toNativeString() {
        return "groupByKey()";
      }
    };
  }

  private static <K, InputT, OutputT>
      TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>> combineGrouped() {
    return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() {
      @Override
      public void evaluate(
          final Combine.GroupedValues<K, InputT, OutputT> transform, EvaluationContext context) {
        // get the applied combine function.
        PCollection<? extends KV<K, ? extends Iterable<InputT>>> input =
            context.getInput(transform);
        final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
        @SuppressWarnings("unchecked")
        final CombineWithContext.CombineFnWithContext<InputT, ?, OutputT> fn =
            (CombineWithContext.CombineFnWithContext<InputT, ?, OutputT>)
                CombineFnUtil.toFnWithContext(transform.getFn());

        @SuppressWarnings("unchecked")
        UnboundedDataset<KV<K, Iterable<InputT>>> unboundedDataset =
            (UnboundedDataset<KV<K, Iterable<InputT>>>) context.borrowDataset(transform);
        JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> dStream =
            unboundedDataset.getDStream();

        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
        final SparkPCollectionView pviews = context.getPViews();

        JavaDStream<WindowedValue<KV<K, OutputT>>> outStream =
            dStream.transform(
                new Function<
                    JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>>,
                    JavaRDD<WindowedValue<KV<K, OutputT>>>>() {
                  @Override
                  public JavaRDD<WindowedValue<KV<K, OutputT>>> call(
                      JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> rdd) throws Exception {
                    SparkKeyedCombineFn<K, InputT, ?, OutputT> combineFnWithContext =
                        new SparkKeyedCombineFn<>(
                            fn,
                            runtimeContext,
                            TranslationUtils.getSideInputs(
                                transform.getSideInputs(),
                                new JavaSparkContext(rdd.context()),
                                pviews),
                            windowingStrategy);
                    return rdd.map(
                        new TranslationUtils.CombineGroupedValues<>(combineFnWithContext));
                  }
                });

        context.putDataset(
            transform, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
      }

      @Override
      public String toNativeString() {
        return "map(new <fn>())";
      }
    };
  }
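
  /**
   * Translates {@link ParDo.MultiOutput} into a {@code mapPartitionsToPair} over each RDD in
   * the stream, tagging every output element with its {@link TupleTag} so that per-output
   * streams can be filtered out afterwards.
   */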
  private static <InputT, OutputT>
      TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
    return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
      @Override
      public void evaluate(
          final ParDo.MultiOutput<InputT, OutputT> transform, final EvaluationContext context) {
        final DoFn<InputT, OutputT> doFn = transform.getFn();
        rejectSplittable(doFn);
        rejectStateAndTimers(doFn);
        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
        final SparkPCollectionView pviews = context.getPViews();
        final WindowingStrategy<?, ?> windowingStrategy =
            context.getInput(transform).getWindowingStrategy();

        @SuppressWarnings("unchecked")
        UnboundedDataset<InputT> unboundedDataset =
            (UnboundedDataset<InputT>) context.borrowDataset(transform);
        JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();

        final String stepName = context.getCurrentTransform().getFullName();

        JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
            dStream.transformToPair(
                new Function<
                    JavaRDD<WindowedValue<InputT>>,
                    JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
                  @Override
                  public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
                      JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
                    final Accumulator<NamedAggregators> aggAccum =
                        AggregatorsAccumulator.getInstance();
                    final Accumulator<MetricsContainerStepMap> metricsAccum =
                        MetricsAccumulator.getInstance();
                    final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
                        sideInputs =
                            TranslationUtils.getSideInputs(
                                transform.getSideInputs(),
                                JavaSparkContext.fromSparkContext(rdd.context()),
                                pviews);
                    return rdd.mapPartitionsToPair(
                        new MultiDoFnFunction<>(
                            aggAccum,
                            metricsAccum,
                            stepName,
                            doFn,
                            runtimeContext,
                            transform.getMainOutputTag(),
                            sideInputs,
                            windowingStrategy));
                  }
                });

        Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
        if (outputs.size() > 1) {
          // cache the DStream if we're going to filter it more than once.
          all.cache();
        }
        for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
          @SuppressWarnings("unchecked")
          JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
              all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
          @SuppressWarnings("unchecked")
          // Object is the best we can do since different outputs can have different tags
          JavaDStream<WindowedValue<Object>> values =
              (JavaDStream<WindowedValue<Object>>)
                  (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
          context.putDataset(
              output.getValue(),
              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
        }
      }

      @Override
      public String toNativeString() {
        return "mapPartitions(new <fn>())";
      }
    };
  }

  private static <K, V, W extends BoundedWindow> TransformEvaluator<Reshuffle<K, V>> reshuffle() {
    return new TransformEvaluator<Reshuffle<K, V>>() {
      @Override
      public void evaluate(Reshuffle<K, V> transform, EvaluationContext context) {
        @SuppressWarnings("unchecked")
        UnboundedDataset<KV<K, V>> inputDataset =
            (UnboundedDataset<KV<K, V>>) context.borrowDataset(transform);
        List<Integer> streamSources = inputDataset.getStreamSources();
        JavaDStream<WindowedValue<KV<K, V>>> dStream = inputDataset.getDStream();
        @SuppressWarnings("unchecked")
        final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder();
        @SuppressWarnings("unchecked")
        final WindowingStrategy<?, W> windowingStrategy =
            (WindowingStrategy<?, W>) context.getInput(transform).getWindowingStrategy();
        @SuppressWarnings("unchecked")
        final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();

        final WindowedValue.WindowedValueCoder<V> wvCoder =
            WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());

        JavaDStream<WindowedValue<KV<K, V>>> reshuffledStream =
            dStream.transform(
                new Function<
                    JavaRDD<WindowedValue<KV<K, V>>>, JavaRDD<WindowedValue<KV<K, V>>>>() {
                  @Override
                  public JavaRDD<WindowedValue<KV<K, V>>> call(
                      JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception {
                    return GroupCombineFunctions.reshuffle(rdd, coder.getKeyCoder(), wvCoder);
                  }
                });

        context.putDataset(transform, new UnboundedDataset<>(reshuffledStream, streamSources));
      }

      @Override
      public String toNativeString() {
        return "repartition(...)";
      }
    };
  }
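
  /** Registered evaluators, keyed by the transform class they translate. */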
  private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS =
      Maps.newHashMap();

  static {
    EVALUATORS.put(Read.Unbounded.class, readUnbounded());
    EVALUATORS.put(GroupByKey.class, groupByKey());
    EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
    EVALUATORS.put(ParDo.MultiOutput.class, parDo());
    EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
    EVALUATORS.put(CreateStream.class, createFromQueue());
    EVALUATORS.put(Window.Assign.class, window());
    EVALUATORS.put(Flatten.PCollections.class, flattenPColl());
    EVALUATORS.put(Reshuffle.class, reshuffle());
  }

  /**
   * Translator matches Beam transformation with the appropriate evaluator.
   */
  public static class Translator implements SparkPipelineTranslator {

    private final SparkPipelineTranslator batchTranslator;

    public Translator(SparkPipelineTranslator batchTranslator) {
      this.batchTranslator = batchTranslator;
    }

    @Override
    public boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz) {
      // streaming includes rdd/bounded transformations as well
      return EVALUATORS.containsKey(clazz);
    }

    @Override
    public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> translateBounded(
        Class<TransformT> clazz) {
      TransformEvaluator<TransformT> transformEvaluator = batchTranslator.translateBounded(clazz);
      checkState(
          transformEvaluator != null,
          "No TransformEvaluator registered for BOUNDED transform %s",
          clazz);
      return transformEvaluator;
    }

    @Override
    public <TransformT extends PTransform<?, ?>> TransformEvaluator<TransformT> translateUnbounded(
        Class<TransformT> clazz) {
      @SuppressWarnings("unchecked")
      TransformEvaluator<TransformT> transformEvaluator =
          (TransformEvaluator<TransformT>) EVALUATORS.get(clazz);
      checkState(
          transformEvaluator != null,
          "No TransformEvaluator registered for UNBOUNDED transform %s",
          clazz);
      return transformEvaluator;
    }
  }
}